mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
52
include/ck_tile/README.md
Normal file
52
include/ck_tile/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
[Back to the main page](../../README.md)
|
||||
# Composable Kernel Tile
|
||||
## concept
|
||||
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator
|
||||
- tensor coordinate transformation, this is the core concept of layout/index transform abstraction in both compiler time and run time.
|
||||
- tile-based programming model, including tile-level api and the concept of distributed tensor.
|
||||
|
||||
`ck_tile` is independently from the old ck, located under [/include/ck_tile](/include/ck_tile). You don't need to include anything from old CK, `ck_tile` has similiar (indeed almost the same) implementations for users to build operators. We will have a transition period to pull everything from old ck into `ck_tile`, stay tuned.
|
||||
|
||||
## component
|
||||
`ck_tile` is splitted into several componenets including `core`, `host`, `ops/gemm`, `ops/fmha`... each component you only need to include a single header (e.g `#include "ck_tile/core.hpp"`, `#include "ck_tile/ops/fmha.hpp"`) then you are able to use the function/structure inside (different from old `ck`)
|
||||
|
||||
**[core]**
|
||||
`ck_tile/core` contains all the basic data structure and function to build the kernel, you can only include this header and build your own operators that utilizing all the basic building blocks introduced in ck.
|
||||
|
||||
`core/container`
|
||||
- array, store runtime variables with fixed length (tensor index, register buffer, etc...)
|
||||
- tuple, same as std::tuple, hold different type of data, and one of the solution to achieve multiple buffer.
|
||||
- sequence, compile time integer sequence used to build various internal structures, or to describe tile size
|
||||
- other convenient structure build on top of above 3
|
||||
|
||||
`core/numeric`
|
||||
- gpu data type like `fp16_t`, `bf16_t`, `fp8_t`... and the conversion between each other
|
||||
- constexpr integer similiar to std::integral_constant to be used as compile time integer.
|
||||
- math functions and numeric utilities
|
||||
|
||||
`core/algorithm`
|
||||
- coordinate transformation system, used to build tensor transform and compile time indexing. This is the core idea introduced in old `ck` to describe how a tensor is build by several basic transform primitives like `merge`/`unmerge`/`embed` etc... and how we indexing into a ND tensor that finally mapped to 1D memory offset.
|
||||
|
||||
`core/tensor`
|
||||
- tensor descriptor, to describe how a ND tensor
|
||||
- distributed tensor, describe the storage of this tensor, and the distribution of how a collection of threads collaborately work for this tensor.
|
||||
- tile level API, including `load_tile`, `store_tile`, `shuffle_tile`, `slice_tile`, etc...
|
||||
|
||||
**[host]**
|
||||
`ck_tile/host` contains all the host side utilities to launch a kernel, create the device buffer, and some reference implementations. This can be used to create examples (like that under ck_tile example folder) and simple executable to invoke this kernel, so if you only need `ck_tile` to build your own device library then it's OK to not include this. Based on this, it is recommended to include the specific header you needed under this folder to avoid including unwanted headers (e.g, only include `ck_tile/host/kernel_launch.hpp`), unless you are writing a host executable.
|
||||
|
||||
**[ops/gemm, ops/fmha, ops/reduce...]**
|
||||
our implementation of different device operators.
|
||||
- warp, warp tile level operator
|
||||
- block, block tile level operator
|
||||
- pipeline, pipeline that can achieve a customized tile level mainloop (or epilogue). By switching different pipeline to the kernel template you can have different kind of pipeline optimizations.
|
||||
- kernel, template interface for users to instantiate a particular kernel
|
||||
|
||||
**[ops/epilogue]**
|
||||
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
|
||||
|
||||
**[ref]**
|
||||
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
|
||||
|
||||
## examples
|
||||
currently we put all ck_tile related example under [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
|
||||
108
include/ck_tile/core.hpp
Normal file
108
include/ck_tile/core.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
|
||||
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
#include "ck_tile/core/container/static_array.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
#include "ck_tile/core/numeric/null_type.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
||||
#include "ck_tile/core/tensor/slice_tile.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/store_tile.hpp"
|
||||
#include "ck_tile/core/tensor/sweep_tile.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_coordinate.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_view.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/gemm_validation.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/mixed_prec_compute_type.hpp"
|
||||
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
18
include/ck_tile/core/README.md
Normal file
18
include/ck_tile/core/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# ck_tile/core #
|
||||
|
||||
`ck_tile/core` contains every basic functions and structures to create a GPU kernel using `ck_tile`. User should only include `ck_tile/core.hpp` this single header to use all the functionality. Everything is under `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structure/function, Camel for template types...)
|
||||
|
||||
```
|
||||
algorithm/
|
||||
coordinate transform and some other reusable algorithm
|
||||
arch/
|
||||
contains some basic device building block like mma, buffer addressing, etc...
|
||||
container/
|
||||
contains basic container data structure, array/sequence/tuple/...
|
||||
numeric/
|
||||
data type, and data type related math
|
||||
tensor/
|
||||
tensor descriptors and tile level API
|
||||
utility/
|
||||
other utility function for both host/device
|
||||
```
|
||||
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Lengths,
|
||||
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
|
||||
const Lengths& lengths,
|
||||
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
|
||||
{
|
||||
constexpr index_t ndim_low = Lengths::size();
|
||||
|
||||
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
|
||||
|
||||
const auto low_lengths = generate_tuple(
|
||||
[&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});
|
||||
|
||||
const auto transform = make_merge_transform(low_lengths);
|
||||
|
||||
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
|
||||
|
||||
constexpr auto up_dim_new_top_ids = sequence<0>{};
|
||||
|
||||
return make_single_stage_tensor_adaptor(
|
||||
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1782
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
1782
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
File diff suppressed because it is too large
Load Diff
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// pre-defined indexing adaptor used for indexing(scatter/gather)
|
||||
|
||||
// this version cache the index inside thread register(which is also prefered in real senario)
|
||||
// however it's user's responsibility that each thread only provide one indexing, which means
|
||||
// move coordinate will not change on this dim
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor_onshot_cached
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
|
||||
: cached_idx_(idx)
|
||||
{
|
||||
}
|
||||
IndexingType cached_idx_;
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = cached_idx_;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
166
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
166
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TensorLengths,
|
||||
typename DimAccessOrder,
|
||||
typename ScalarsPerAccess,
|
||||
bool SnakeCurved = true> // # of scalars per access in each dimension
|
||||
struct space_filling_curve
|
||||
{
|
||||
static constexpr index_t TensorSize =
|
||||
reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{});
|
||||
static_assert(0 < TensorSize,
|
||||
"space_filling_curve should be used to access a non-empty tensor");
|
||||
|
||||
static constexpr index_t nDim = TensorLengths::size();
|
||||
|
||||
using Index = multi_index<nDim>;
|
||||
|
||||
static constexpr index_t ScalarPerVector =
|
||||
reduce_on_sequence(ScalarsPerAccess{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
|
||||
static constexpr auto dim_access_order = DimAccessOrder{};
|
||||
static constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(ordered_access_lengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
|
||||
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
|
||||
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
|
||||
|
||||
return reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{}) / ScalarPerVector;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
|
||||
number<AccessIdx1dTail>)
|
||||
{
|
||||
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
|
||||
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
|
||||
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
|
||||
return idx_tail - idx_head;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
|
||||
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
|
||||
}
|
||||
|
||||
// Do not use this function directly!
|
||||
// TODO: can refactor into generic lambda in the future
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
|
||||
*/
|
||||
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
|
||||
#else
|
||||
|
||||
constexpr auto access_strides =
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr auto idx_1d = number<AccessIdx1d>{};
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
|
||||
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
|
||||
id = res / access_strides[kdim].value;
|
||||
res -= id * access_strides[kdim].value;
|
||||
});
|
||||
|
||||
return id;
|
||||
};
|
||||
|
||||
constexpr auto id = compute_index_impl(idim);
|
||||
return number<id>{};
|
||||
};
|
||||
|
||||
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
|
||||
#endif
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
statically_indexed_array<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto idim) {
|
||||
index_t tmp = ordered_access_idx[I0];
|
||||
|
||||
static_for<1, idim, 1>{}(
|
||||
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
|
||||
|
||||
forward_sweep_(idim) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate multi-dim tensor index
|
||||
auto idx_md = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
ordered_idx(idim) =
|
||||
!SnakeCurved || forward_sweep[idim]
|
||||
? ordered_access_idx[idim]
|
||||
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
|
||||
ScalarsPerAccess{};
|
||||
}();
|
||||
return idx_md;
|
||||
}
|
||||
|
||||
// FIXME: return tuple of number<>, which is compile time only variable
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
|
||||
{
|
||||
constexpr auto idx = _get_index(number<AccessIdx1d>{});
|
||||
|
||||
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
370
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
Normal file
370
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
Normal file
@@ -0,0 +1,370 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file
|
||||
* We're defining the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
|
||||
for `BlockSize` threads in a thread block.
|
||||
* X dimension is considered contiguous in memory, so a single instruction can access
|
||||
several adjacent and properly aligned elements (vector); the access pattern along X tile
|
||||
dimension is parameterized only by the suggested vector size `VecSize`.
|
||||
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
|
||||
with a single memory access, so the actual vector size along the X dimension is
|
||||
`X0 = min(MaxVecSize, VecSize)`.
|
||||
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
|
||||
* X1 is also the number of threads per warp in X dimension, that is,
|
||||
X dimension is not split between warps, and each warp accesses X dimension entirely,
|
||||
and there is no iteration in X dimension.
|
||||
* The tuple <X0, X1> defines the X-axis access pattern.
|
||||
This part is common between the 2D distribution patterns.
|
||||
|
||||
* What's different between the different 2D distribution patterns, is the Y axis access pattern.
|
||||
* There are 3 components in this access pattern;
|
||||
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
|
||||
* (2) number of warps per thread block,
|
||||
* (3) number of iterations to cover the entire Y axis.
|
||||
|
||||
* The raked here represents how data is partitioned across different processing granularity.
|
||||
* It represents how we are going to access the data in thread, warp, or blocked in contiguous
|
||||
region.
|
||||
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
|
||||
* in the split of Y tile dimension where the iteration happens,
|
||||
* meaning, the iteration can be logically inserted as a tile dimension in 3 ways,
|
||||
* (1) after thread -> thread-raked,
|
||||
* (2) between warp and thread -> warp-raked,
|
||||
* (3) before warp -> block-raked
|
||||
|
||||
* *Thread raked*
|
||||
|
||||
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of rows accessed by a warp within a single iteration,
|
||||
compute it from the equation `Y0 * X1 == WarpSize`
|
||||
* Y2 is the number of iterations to cover the tile,
|
||||
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
|
||||
|
||||
* *Warp raked*
|
||||
|
||||
* Y0 is the number of warps, we can get it in the same way as for thread-raked pattern,
|
||||
`Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y2 from the equation below
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* *Block raked*
|
||||
|
||||
* Y0 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y1 and Y2 from the equations below
|
||||
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
|
||||
|
||||
* *Selection*
|
||||
* When we are selecting, Thread-raked is used in element-wise operation because it is the
|
||||
* Thread-major memory order.
|
||||
* Warp-raked is used in matrix multiplication because the vectorization is in warp level.
|
||||
* Block-raked is used mostly for the reduction process, where will reduce the block in global
|
||||
* atomic level.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Enumeration describing static tile distribution patterns.
|
||||
*
|
||||
*/
|
||||
enum struct tile_distribution_pattern
|
||||
{
|
||||
/**
|
||||
* @brief Thread raked pattern.
|
||||
*
|
||||
*/
|
||||
thread_raked,
|
||||
/**
|
||||
* @brief Warp raked pattern.
|
||||
*
|
||||
*/
|
||||
warp_raked,
|
||||
/**
|
||||
* @brief Block raked pattern - aka linear.
|
||||
*
|
||||
*/
|
||||
block_raked
|
||||
};
|
||||
|
||||
struct tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Class creating 2D static tile distribution with different load/store patterns.
|
||||
*
|
||||
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
|
||||
* is contiguous and we can do vector load on this dimension.
|
||||
*
|
||||
* @tparam BlockSize Number of threads in a workgroup.
|
||||
* @tparam YPerTile The tile size of outer/leftmost dimension.
|
||||
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
|
||||
* @tparam VecSize The vector access size.
|
||||
* @tparam DistributionPattern The enumeration describing used access pattern.
|
||||
*/
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = min(warp_size, XPerTile / X1); // # of threads in X dim
|
||||
|
||||
// # of rows in Y dim accessed by single wavefront in one iteration
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps / NumWaveGroups;
|
||||
// YPerWarp = YPerTile / Y0;
|
||||
// Y2 = YPerWarp / Y1;
|
||||
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
|
||||
|
||||
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <X1, Y2>
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{}); // -> <X1, Y2>
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = min(warp_size, XPerTile / X1); // # of threads in X dim
|
||||
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
|
||||
|
||||
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <X1, Y1>
|
||||
}
|
||||
};
|
||||
|
||||
// Block raked
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = min(warp_size, XPerTile / X1); // # of threads in X dim
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
static constexpr index_t Y1 = num_warps;
|
||||
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
|
||||
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{}); // -> <X1, Y0>
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to convert enum to string
|
||||
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
|
||||
{
|
||||
switch(pattern)
|
||||
{
|
||||
case tile_distribution_pattern::thread_raked: return "thread_raked";
|
||||
case tile_distribution_pattern::warp_raked: return "warp_raked";
|
||||
case tile_distribution_pattern::block_raked: return "block_raked";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern_to_string(DistributionPattern));
|
||||
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
|
||||
PatternType::Y0,
|
||||
PatternType::Y1,
|
||||
PatternType::Y2,
|
||||
PatternType::X0,
|
||||
PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
3066
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
3066
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2947
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2947
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
124
include/ck_tile/core/arch/amd_buffer_coherence.hpp
Normal file
124
include/ck_tile/core/arch/amd_buffer_coherence.hpp
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// memory coherency bit for buffer store/load instruction
|
||||
// check ISA manual for each GFX target
|
||||
// e.g. for
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
|
||||
// page 67~68
|
||||
enum struct amd_buffer_coherence_enum
|
||||
{
|
||||
coherence_default = 0, // default value
|
||||
#if defined(__gfx12__)
|
||||
// Temporal hint
|
||||
RT = 0, // regular temporal
|
||||
NT = 1, // non temporal
|
||||
HT = 2, // high priority temporal
|
||||
LU = 3, // last use (load op)
|
||||
WB = 3, // same as HT, overrides WR in far cache (store op)
|
||||
NT_RT = 4, // non temporal for near cache, regular for far cache
|
||||
RT_NT = 5, // regular for near cache, non-temporal for far cache
|
||||
NT_HT = 6, // non temporal for near cache, high priority for far cache
|
||||
NT_WB = 7, // non temporal for near cache, WB for far cache
|
||||
// (store op, reserved for load op)
|
||||
// Scope
|
||||
CU = 0,
|
||||
SE = 8,
|
||||
DEVICE = 16,
|
||||
SYSTEM = 24,
|
||||
// Temporal Hint for CU
|
||||
CU_RT = RT | CU,
|
||||
CU_NT = NT | CU,
|
||||
CU_HT = HT | CU,
|
||||
CU_LU = LU | CU,
|
||||
CU_WB = WB | CU,
|
||||
CU_NT_RT = NT_RT | CU,
|
||||
CU_RT_NT = RT_NT | CU,
|
||||
CU_NT_HT = NT_HT | CU,
|
||||
CU_NT_WB = NT_WB | CU,
|
||||
// Temporal Hint for SE
|
||||
SE_RT = RT | SE,
|
||||
SE_NT = NT | SE,
|
||||
SE_HT = HT | SE,
|
||||
SE_LU = LU | SE,
|
||||
SE_WB = WB | SE,
|
||||
SE_NT_RT = NT_RT | SE,
|
||||
SE_RT_NT = RT_NT | SE,
|
||||
SE_NT_HT = NT_HT | SE,
|
||||
SE_NT_WB = NT_WB | SE,
|
||||
// Temporal Hint for DEVICE
|
||||
DEVICE_RT = RT | DEVICE,
|
||||
DEVICE_NT = NT | DEVICE,
|
||||
DEVICE_HT = HT | DEVICE,
|
||||
DEVICE_LU = LU | DEVICE,
|
||||
DEVICE_WB = WB | DEVICE,
|
||||
DEVICE_NT_RT = NT_RT | DEVICE,
|
||||
DEVICE_RT_NT = RT_NT | DEVICE,
|
||||
DEVICE_NT_HT = NT_HT | DEVICE,
|
||||
DEVICE_NT_WB = NT_WB | DEVICE,
|
||||
// Temporal Hint for SYSTEM
|
||||
SYSTEM_RT = RT | SYSTEM,
|
||||
SYSTEM_NT = NT | SYSTEM,
|
||||
SYSTEM_HT = HT | SYSTEM,
|
||||
SYSTEM_LU = LU | SYSTEM,
|
||||
SYSTEM_WB = WB | SYSTEM,
|
||||
SYSTEM_NT_RT = NT_RT | SYSTEM,
|
||||
SYSTEM_RT_NT = RT_NT | SYSTEM,
|
||||
SYSTEM_NT_HT = NT_HT | SYSTEM,
|
||||
SYSTEM_NT_WB = NT_WB | SYSTEM,
|
||||
|
||||
// GFX942 and GFX950 compatiblity
|
||||
GROUP_NT0 = CU_RT,
|
||||
GROUP_NT1 = CU_NT,
|
||||
DEVICE_NT0 = DEVICE_RT,
|
||||
DEVICE_NT1 = DEVICE_NT,
|
||||
SYSTEM_NT0 = SYSTEM_RT,
|
||||
SYSTEM_NT1 = SYSTEM_NT,
|
||||
// Other archs compatiblity
|
||||
glc = DEVICE_NT,
|
||||
slc = SYSTEM_NT,
|
||||
glc_slc = DEVICE_NT | SYSTEM_NT,
|
||||
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
#elif defined(__gfx942__) || defined(__gfx950__)
|
||||
|
||||
WAVE = 0,
|
||||
GROUP = 1,
|
||||
DEVICE = 16,
|
||||
SYSTEM = 17,
|
||||
NT0 = 0,
|
||||
NT1 = 2,
|
||||
|
||||
WAVE_NT0 = NT0 | WAVE,
|
||||
WAVE_NT1 = NT1 | WAVE,
|
||||
GROUP_NT0 = NT0 | GROUP,
|
||||
GROUP_NT1 = NT1 | GROUP,
|
||||
DEVICE_NT0 = NT0 | DEVICE,
|
||||
DEVICE_NT1 = NT1 | DEVICE,
|
||||
SYSTEM_NT0 = NT0 | SYSTEM,
|
||||
SYSTEM_NT1 = NT1 | SYSTEM,
|
||||
|
||||
// Other archs compatiblity
|
||||
glc = DEVICE_NT1,
|
||||
slc = SYSTEM_NT1,
|
||||
glc_slc = DEVICE_NT1 | SYSTEM_NT1,
|
||||
#else
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
|
||||
// Other archs compatiblity
|
||||
DEVICE_NT0 = 0,
|
||||
SYSTEM_NT0 = 0,
|
||||
DEVICE_NT1 = glc,
|
||||
SYSTEM_NT1 = slc,
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, index_t LaneGroupSize = 16, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
|
||||
"LaneGroupSize must be 16, 32, or 64");
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t LaneGroupSize,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1216
include/ck_tile/core/arch/arch.hpp
Normal file
1216
include/ck_tile/core/arch/arch.hpp
Normal file
File diff suppressed because it is too large
Load Diff
529
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
529
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,529 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
|
||||
{
|
||||
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
|
||||
{
|
||||
bf16x2_t rtn;
|
||||
rtn[0] = add<bf16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
|
||||
{
|
||||
bf16x4_t rtn;
|
||||
rtn[0] = add<bf16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf16_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf16_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf16_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
|
||||
{
|
||||
fp16x2_t rtn;
|
||||
rtn[0] = add<fp16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
|
||||
{
|
||||
fp8x4_t rtn;
|
||||
rtn[0] = add<fp8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<fp8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<fp8_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
|
||||
{
|
||||
fp8x8_t rtn;
|
||||
rtn[0] = add<fp8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<fp8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<fp8_t, float>(a[3], b[3]);
|
||||
rtn[4] = add<fp8_t, float>(a[4], b[4]);
|
||||
rtn[5] = add<fp8_t, float>(a[5], b[5]);
|
||||
rtn[6] = add<fp8_t, float>(a[6], b[6]);
|
||||
rtn[7] = add<fp8_t, float>(a[7], b[7]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
|
||||
{
|
||||
bf8x4_t rtn;
|
||||
rtn[0] = add<bf8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf8_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
|
||||
{
|
||||
bf8x8_t rtn;
|
||||
rtn[0] = add<bf8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf8_t, float>(a[3], b[3]);
|
||||
rtn[4] = add<bf8_t, float>(a[4], b[4]);
|
||||
rtn[5] = add<bf8_t, float>(a[5], b[5]);
|
||||
rtn[6] = add<bf8_t, float>(a[6], b[6]);
|
||||
rtn[7] = add<bf8_t, float>(a[7], b[7]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast<bf16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32BF162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf16x2_t* bf162_a;
|
||||
};
|
||||
|
||||
union U32BF162
|
||||
{
|
||||
uint32_t u32;
|
||||
bf16x2_t bf162;
|
||||
};
|
||||
|
||||
U32BF162_ADDR dword_addr;
|
||||
U32BF162 cur_v;
|
||||
U32BF162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.bf162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
|
||||
{
|
||||
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
|
||||
union U64BF164_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf16x4_t* bf164_a;
|
||||
};
|
||||
|
||||
// Union to treat the data as either bf16x4_t or 64-bit integer
|
||||
union U64BF164
|
||||
{
|
||||
uint64_t u64;
|
||||
bf16x4_t bf164;
|
||||
};
|
||||
|
||||
U64BF164_ADDR addr;
|
||||
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
|
||||
|
||||
// First read (non-atomic) of the old value
|
||||
U64BF164 cur_v;
|
||||
cur_v.u64 = *addr.u64_a;
|
||||
|
||||
U64BF164 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
do
|
||||
{
|
||||
// old 64 bits
|
||||
old_v = cur_v.u64;
|
||||
|
||||
// Add elementwise in bf16
|
||||
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt the 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
|
||||
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
|
||||
{
|
||||
union U32FP84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp8x4_t* fp84_a;
|
||||
};
|
||||
|
||||
union U32FP84
|
||||
{
|
||||
uint32_t u32;
|
||||
fp8x4_t fp84;
|
||||
};
|
||||
|
||||
U32FP84_ADDR dword_addr;
|
||||
U32FP84 cur_v;
|
||||
U32FP84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.fp84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
|
||||
{
|
||||
union U32BF84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf8x4_t* bf84_a;
|
||||
};
|
||||
|
||||
union U32BF84
|
||||
{
|
||||
uint32_t u32;
|
||||
bf8x4_t bf84;
|
||||
};
|
||||
|
||||
U32BF84_ADDR dword_addr;
|
||||
U32BF84 cur_v;
|
||||
U32BF84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.bf84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
|
||||
{
|
||||
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
|
||||
union U64FP88_ADDR
|
||||
{
|
||||
uint64_t* u64_a; // pointer to 64-bit integer
|
||||
fp8x8_t* fp88_a; // pointer to fp8x8_t
|
||||
};
|
||||
|
||||
union U64FP88
|
||||
{
|
||||
uint64_t u64;
|
||||
fp8x8_t fp88;
|
||||
};
|
||||
|
||||
U64FP88_ADDR dword_addr;
|
||||
U64FP88 cur_v;
|
||||
U64FP88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
// Point to the destination as both fp8x8_t* and uint64_t*.
|
||||
dword_addr.fp88_a = p_dst;
|
||||
// Initial read of 64 bits from memory
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each fp8 element using your add_fp8x8_t(...) routine
|
||||
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for bf8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
{
|
||||
union U64BF88_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf8x8_t* bf88_a;
|
||||
};
|
||||
|
||||
union U64BF88
|
||||
{
|
||||
uint64_t u64;
|
||||
bf8x8_t bf88;
|
||||
};
|
||||
|
||||
U64BF88_ADDR dword_addr;
|
||||
U64BF88 cur_v;
|
||||
U64BF88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
dword_addr.bf88_a = p_dst;
|
||||
// Read the original 64 bits
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each bf8 element using your add_bf8x8_t(...) routine
|
||||
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// 64-bit CAS loop
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp16x2_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32F162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* f162_a;
|
||||
};
|
||||
|
||||
union U32F162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t f162;
|
||||
};
|
||||
|
||||
U32F162_ADDR dword_addr;
|
||||
U32F162 cur_v;
|
||||
U32F162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.f162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
"The granularity of the thread buffer is unsupported on the hardware!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, x.template get_as<float>()[I2]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, x.template get_as<float>()[I3]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return atomicAdd(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x4_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
|
||||
x.template get_as<fp16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
128
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
128
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct Unsupported
|
||||
* @brief Meta-tag to indicate unsupported amdgcn_mma instance.
|
||||
*/
|
||||
struct Unsupported;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
/**
|
||||
* @concept MmaOpI
|
||||
* @brief Expresses the meta-data interface required for each MmaOp policy.
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
typename MmaOp::OpType;
|
||||
|
||||
// Captures types for inputs / outputs to mma function
|
||||
typename MmaOp::AVecType;
|
||||
typename MmaOp::BVecType;
|
||||
typename MmaOp::CVecType;
|
||||
|
||||
// Captures CK-specific layout properties
|
||||
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
|
||||
|
||||
// Static exec function
|
||||
{
|
||||
MmaOp::exec(
|
||||
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
|
||||
} -> std::convertible_to<typename MmaOp::CVecType>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @class amdgcn_mma
|
||||
* @brief This is the default MmaOp policy.
|
||||
* Instances of this class are to be used as MmaOp policies.
|
||||
* Light builtin wrapper for mfma / wmma instructions. This class's job is to
|
||||
* provide a uniform interface to invoke the appropriate instruction
|
||||
* based on the template parameters provided. This interface is to bridge
|
||||
* the gap between the ck_tile API types and the native __builtin types.
|
||||
* @tparam ADataType Datatype of input A
|
||||
* @tparam BDataType Datatype of input B
|
||||
* @tparam CDataType Datatype of accumulator
|
||||
* @tparam BlockM M-dimension of mma block
|
||||
* @tparam BlockN N-dimension of mma block
|
||||
* @tparam BlockK K-dimension of mma block
|
||||
* @tparam CtrlFlags Control flags for mma operation
|
||||
* @tparam CompilerTarget The current compiler target
|
||||
* @tparam Enabler SFINAE enabler
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
struct amdgcn_mma
|
||||
{
|
||||
// The base instance is unsupported because there is no __builtin to wrap.
|
||||
using OpType = Unsupported;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
|
||||
|
||||
// Interface types for A, B, C vectors types
|
||||
using AVecType = ext_vector_t<ADataType, 1>;
|
||||
using BVecType = ext_vector_t<BDataType, 1>;
|
||||
using CVecType = ext_vector_t<CDataType, 1>;
|
||||
|
||||
// Layout constants - default to 0
|
||||
static constexpr index_t kAMBlock = 0;
|
||||
static constexpr index_t kBNBlock = 0;
|
||||
|
||||
static constexpr index_t kAMLane = 0;
|
||||
static constexpr index_t kBNLane = 0;
|
||||
static constexpr index_t kABKLane = 0;
|
||||
static constexpr index_t kABKPerLane = 0;
|
||||
|
||||
static constexpr index_t kCMLane = 0;
|
||||
static constexpr index_t kCNLane = 0;
|
||||
static constexpr index_t kCM0PerLane = 0;
|
||||
static constexpr index_t kCM1PerLane = 0;
|
||||
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
CK_TILE_DEVICE static CVecType const&
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
ignore(regsA, regsB);
|
||||
return regsC; // No-op, just return C
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "sparse/sparse.hpp"
|
||||
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Include the architecture-specific MFMA implementations and traits
|
||||
#include "mfma_gfx9.hpp"
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_selector.hpp"
|
||||
#include "mfma_transforms.hpp"
|
||||
168
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
168
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @struct DefaultMmaCtrlFlags
|
||||
* @brief Default MFMA flags, no broadcasting or rotation of inputs
|
||||
*/
|
||||
struct DefaultMfmaCtrlFlags
|
||||
{
|
||||
static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
|
||||
static constexpr uint32_t Abid = 0; // ABID flag, default 0
|
||||
static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept CtrlFlagsGfx9I
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for Gfx9 MFMA instructions
|
||||
{ CtrlFlags::Cbsz } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::Abid } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::Blgp } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Packed register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
195
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
195
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_gfx9.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MfmaDefaultSelector
|
||||
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported MFMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through
|
||||
implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam BlockKTest Current Block K dimension size to test
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @note Here we assume that BlockKTest is always a power-of-two integer.
|
||||
* The search strategy starts from a maximum BlockKTest size down to 1u by halving
|
||||
* each time.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct MfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MfmaDefaultSelector
|
||||
* @brief Implements the base case for the default MFMA selector when no supported instruction is
|
||||
* found.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported MFMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp4x4 =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 4u, 4u, 4u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
using CandidateOp16x16 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported4x4 =
|
||||
CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<
|
||||
IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16,
|
||||
CandidateOp16x16,
|
||||
std::conditional_t<IsSupported4x4, CandidateOp4x4, DefaultOp>>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MfmaOp
|
||||
* @brief Meta-tag for the MFMA operation. This will be used in the MmaOp policies to
|
||||
* identify the operation as an MFMA operation.
|
||||
*/
|
||||
struct MfmaOp;
|
||||
|
||||
/**
|
||||
* @class is_mma_op_mfma
|
||||
* @brief Trait to check if MmaOp is an MFMA operation
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_mfma : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_mfma
|
||||
* @brief MmaOp specialization for MFMA operations, confirming the OpType matches MfmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, MfmaOp>>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_mfma trait
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx9
|
||||
* @brief Implements the default MMA transforms for gfx9 targets
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx9
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx9 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx9;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
235
include/ck_tile/core/arch/mma/mma.hpp
Normal file
235
include/ck_tile/core/arch/mma/mma.hpp
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_traits.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "wmma/wmma.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaAccumPolicy
|
||||
* @brief Accumulation order for Mma decomposition
|
||||
*/
|
||||
enum struct MmaAccumPolicy
|
||||
{
|
||||
// Decomposition and accumulation in row-major block order
|
||||
ROW_MAJOR,
|
||||
// Decomposition and accumulation in col-major block order
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
|
||||
* fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
|
||||
* (C: FragM x FragN).
|
||||
* @tparam ADataType Data type of input fragment A
|
||||
* @tparam BDataType Data type of input fragment B
|
||||
* @tparam CDataType Data type of input/output fragment C (accumulator)
|
||||
* @tparam FragM Mma fragment M dimension
|
||||
* @tparam FragN Mma fragment K dimension
|
||||
* @tparam FragK Mma fragment M dimension
|
||||
* @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
|
||||
* order)
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
|
||||
* wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output fragments
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output fragment. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
{
|
||||
|
||||
using BlockWiseMmaOp = MmaOp;
|
||||
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
|
||||
|
||||
// Block dimensions
|
||||
constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
|
||||
constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
|
||||
constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
|
||||
|
||||
// Block counts for decomposition
|
||||
constexpr static uint32_t BlocksM = FragM / BlockM;
|
||||
constexpr static uint32_t BlocksN = FragN / BlockN;
|
||||
constexpr static uint32_t BlocksK = FragK / BlockK;
|
||||
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
|
||||
|
||||
// Vector types for packed registers in each block
|
||||
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
|
||||
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
|
||||
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
|
||||
|
||||
// Buffer types for fragments
|
||||
using ABufferType = AVecType[BlocksM][BlocksK];
|
||||
using BBufferType = BVecType[BlocksN][BlocksK];
|
||||
using CBufferType = CVecType[BlocksM][BlocksN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
|
||||
static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
|
||||
static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
|
||||
static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
|
||||
static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
|
||||
static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
|
||||
|
||||
private:
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT const&>(inputBuffer);
|
||||
}
|
||||
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT&>(inputBuffer);
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Col-major" accumulation over the M-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in col-major order
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Row-major" accumulation over the N-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in row-major order.
|
||||
// We also have to ensure that the incoming vector fragments are converted to native vector
|
||||
// types before passing to the BlockWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
public:
|
||||
/*! @brief Forward to Mma operation with specified accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
|
||||
{
|
||||
return exec_row_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
|
||||
{
|
||||
return exec_col_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum MmaOpFamily
|
||||
* @brief Enumeration that defines mma op families and
|
||||
*/
|
||||
enum struct MmaOpFamily
|
||||
{
|
||||
UNDEFINED = 0,
|
||||
DENSE,
|
||||
SPARSE,
|
||||
SCALE,
|
||||
};
|
||||
|
||||
/**
|
||||
* @class is_ctrl_fis_mma_op_of_familylag_of_family
|
||||
* @brief Meta-function to check if MmaOp is of the specified MmaOpFamily
|
||||
* @tparam Family Control flag family
|
||||
* @tparam MmaOp amdgcn struct specialization type
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp, typename = void>
|
||||
struct is_mma_op_of_family : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_of_family
|
||||
* @brief Specialization for Family == MmaOp::OpFamily detection
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
struct is_mma_op_of_family<Family, MmaOp, std::enable_if_t<Family == MmaOp::OpFamily>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_of_family trait
|
||||
* @tparam Family Desired control flag family
|
||||
* @tparam MmaOp The amdgcn struct specialization type to check
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
static constexpr bool is_mma_op_of_family_v = is_mma_op_of_family<Family, MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
75
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
75
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MmaDefaultSelector
|
||||
* @brief Implements a default mma selector strategy for the current target architecture.
|
||||
* This is simply intended as a default selection strategy for mma instruction operations.
|
||||
* Given the particular datatypes and Fragment dimensions, the selector will attempt to
|
||||
* select the instruction with the largest K dimension that is supported on the current target
|
||||
* architecture.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Fragment M dimension
|
||||
* @tparam FragN Fragment N dimension
|
||||
* @tparam FragK Fragment K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam Enable SFINAE enabler
|
||||
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
|
||||
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
|
||||
* correspond to the size of the individual MMA instructions being used to compute the overall in
|
||||
* block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
|
||||
* equal to the Block sizes.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily,
|
||||
typename Enable = void>
|
||||
// TODO c++20 requires
|
||||
struct MmaDefaultSelector
|
||||
{
|
||||
// By default, no selection is made, and we fall back to a pass-through unsupported
|
||||
// implementation. This is because we do not have any knowledge of the target architecture here.
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
/**
|
||||
* @concept MmaSelectorI
|
||||
* @brief Expresses the required members for each MmaSelector class.
|
||||
*/
|
||||
template <typename MmaSelector>
|
||||
concept MmaSelectorI = requires(MmaSelector op) {
|
||||
// Selectors should have a resulting SelectedOp type
|
||||
typename MmaSelector::SelectedOp;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma_selector.hpp"
|
||||
#include "mfma/mfma_selector.hpp"
|
||||
164
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
164
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "mfma/mfma_traits.hpp"
|
||||
#include "wmma/wmma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class is_mma_op_supported
|
||||
* @brief Trait to check if MmaOp is supported
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, typename = void>
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_supported : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_supported
|
||||
* @brief The MmaOp is unsupported specialization
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_supported<MmaOp,
|
||||
std::enable_if_t<std::is_same_v<typename MmaOp::OpType, Unsupported>>>
|
||||
: std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluation of is_mma_op_supported
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @class MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
struct MmaOpParams;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept MmaOpParamsI
|
||||
* @brief Expresses the required members for each MmaOp
|
||||
*/
|
||||
template <typename MmaOpParams>
|
||||
concept MmaOpParamsI = requires(MmaOpParams op) {
|
||||
// Capture template parameters
|
||||
typename MmaOpParams::ADataType;
|
||||
typename MmaOpParams::BDataType;
|
||||
typename MmaOpParams::CDataType;
|
||||
typename MmaOpParams::CtrlFlags;
|
||||
|
||||
{ MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
|
||||
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @tparam ADataType_ Data type of matrix A
|
||||
* @tparam BDataType_ Data type of matrix B
|
||||
* @tparam CDataType_ Data type of the accumulator
|
||||
* @tparam BlockM_ Size of the M dimension
|
||||
* @tparam BlockN_ Size of the N dimension
|
||||
* @tparam BlockK_ Size of the K dimension
|
||||
* @tparam CtrlFlags_ Control flags for the MMA operation
|
||||
* @tparam CompilerTarget_ The compiler target
|
||||
*/
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
uint32_t BlockM_,
|
||||
uint32_t BlockN_,
|
||||
uint32_t BlockK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_,
|
||||
MmaOpFamily OpFamily_>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
|
||||
struct MmaOpParams<amdgcn_mma<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockM_,
|
||||
BlockN_,
|
||||
BlockK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>>
|
||||
{
|
||||
// Capture incoming template parameters
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
static constexpr uint32_t BlockM = BlockM_;
|
||||
static constexpr uint32_t BlockN = BlockN_;
|
||||
static constexpr uint32_t BlockK = BlockK_;
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
static constexpr auto MmaOpFamily = OpFamily_;
|
||||
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaOpTraits
|
||||
* @brief Reflects the template parameters and static members of a given MmaOp.
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
|
||||
struct MmaOpTraits : public MmaOpParams<MmaOp>
|
||||
{
|
||||
// Capture internal MmaOp static members
|
||||
using OpType = typename MmaOp::OpType;
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
|
||||
|
||||
// Capture layout parameters
|
||||
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
|
||||
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
|
||||
static constexpr index_t kAMLane = MmaOp::kAMLane;
|
||||
static constexpr index_t kBNLane = MmaOp::kBNLane;
|
||||
static constexpr index_t kABKLane = MmaOp::kABKLane;
|
||||
static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
|
||||
static constexpr index_t kCMLane = MmaOp::kCMLane;
|
||||
static constexpr index_t kCNLane = MmaOp::kCNLane;
|
||||
static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
|
||||
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
|
||||
|
||||
// Additional traits to identify the type of MmaOp at compile time
|
||||
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
|
||||
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
|
||||
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
|
||||
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
|
||||
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
|
||||
constexpr static bool IsSupported =
|
||||
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
49
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
49
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct PassThroughTransform
|
||||
* @brief A no-op transform that passes through the input as-is.
|
||||
*/
|
||||
struct PassThroughTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaTransformsDefaultSelector
|
||||
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
|
||||
* @tparam MmaOp The Mma operation type
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam Enable SFINAE parameter for specialization
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
|
||||
struct MmaTransformsDefaultSelector;
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
/**
|
||||
* @concept MmaTransformsI
|
||||
* @brief Expresses the interface of required members for each MmaTransforms type.
|
||||
*/
|
||||
template <typename MmaTransforms>
|
||||
concept MmaTransformsI = requires(MmaTransforms transforms) {
|
||||
// Transforms should define TransformA, TransformB, TransformC, and TransformD types
|
||||
typename MmaTransforms::ATransform;
|
||||
typename MmaTransforms::BTransform;
|
||||
typename MmaTransforms::CTransform;
|
||||
typename MmaTransforms::DTransform;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class SparseMfmaDefaultSelector
|
||||
* @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct SparseMfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
|
||||
amdgcn_target_id::GFX942,
|
||||
amdgcn_target_id::GFX950)>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp =
|
||||
std::conditional_t<IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DefaultSparseMfmaCtrlFlags
|
||||
* @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
|
||||
* 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
|
||||
*/
|
||||
struct DefaultSparseMfmaCtrlFlags
|
||||
{
|
||||
static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept SparseMfmaCtrlFlags
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for sparse MFMA instructions
|
||||
{ CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
|
||||
*
|
||||
* This specialization implements the SMFMA instruction for fp16_t A and B
|
||||
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<
|
||||
fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE,
|
||||
std::enable_if_t<is_any_value_of(
|
||||
CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
|
||||
|
||||
static constexpr index_t ABVecN = 8;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using BVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
static_assert(CompressedSize == 4);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
|
||||
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum SparseCompressionIndex
|
||||
* @brief Indicates which set of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data)
|
||||
* of index information for a lane. \see DefaultSparseMfmaCtrlFlags
|
||||
*/
|
||||
enum struct SparseCompressionIndex : int
|
||||
{
|
||||
FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively
|
||||
SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
|
||||
THIRD = 2, // Uses bits [23:16]
|
||||
FOURTH = 3, // Uses bits [31:24]
|
||||
};
|
||||
|
||||
namespace sparse::detail {
|
||||
|
||||
/**
|
||||
* @struct BuiltinParams
|
||||
* @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse
|
||||
* builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data:
|
||||
* If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected
|
||||
* (VGPR[srcC][7..0]).
|
||||
*
|
||||
* 8-bit source data:
|
||||
* If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected
|
||||
* (VGPR[srcC][15..0]).
|
||||
*/
|
||||
struct BuiltinParams
|
||||
{
|
||||
int UseFirstIndex; // CBSZ
|
||||
int ByteIndexToOverride; // ABID
|
||||
};
|
||||
|
||||
template <SparseCompressionIndex Idx>
|
||||
static constexpr BuiltinParams getBuiltinParams()
|
||||
{
|
||||
BuiltinParams params;
|
||||
if constexpr(Idx == SparseCompressionIndex::FIRST)
|
||||
{
|
||||
params.UseFirstIndex = 1;
|
||||
params.ByteIndexToOverride = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
params.UseFirstIndex = 0;
|
||||
params.ByteIndexToOverride = static_cast<int>(Idx);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
} // namespace sparse::detail
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include sparse MFMA traits and architecture-specific implementations
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsSparse
|
||||
* @brief Implements the default transforms for Sparse
|
||||
*
|
||||
* For 2:4 structured sparsity with inline register metadata:
|
||||
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
|
||||
* - BTransform: Pass-through (sparse operands already formatted)
|
||||
* - CTransform: Pass-through (input accumulator)
|
||||
* - DTransform: Pass-through (output accumulator as-is)
|
||||
*/
|
||||
struct MmaDefaultTransformsSparse
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaTransformsDefaultSelector
|
||||
* @brief Specialization for Sparse MFMA transforms
|
||||
* Provides default transform selection for sparse operations
|
||||
*
|
||||
* @tparam MmaOp Sparse MMA operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires(is_mma_op_sparse(MmaOp))
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsSparse;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class SparseWmmaDefaultSelector
|
||||
* @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct SparseWmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultSparseWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the RDNA default MMA selector strategy for sparse WMMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common WMMA shapes.
|
||||
// Start searching from the largest K dimension WMMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename SparseWmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp = typename SparseWmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
struct DefaultSparseWmmaCtrlFlags
|
||||
{
|
||||
};
|
||||
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
|
||||
|
||||
static constexpr index_t ABVecN = 16;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using BVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
static_assert(CompressedSize == 8);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {
|
||||
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
|
||||
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file tile_distribution_encoding_register_mapper.hpp
|
||||
* @brief Utility for register / matrix coordinate mapping from TileDistributionEncoding
|
||||
* @details Defines TileDistrEncRegMap, which takes a TileDistributionEncoding and provides
|
||||
* functions for mapping matrix fragment coordinates to register coordinates (lane, vector item) and
|
||||
* vice versa. This is only meant for tile distributions encodings that describe register mappings.
|
||||
*
|
||||
* A repeat dimension is allowed in which case multiple (lane, vector item) pairs are mapped to the
|
||||
* same matrix coordinates. The inverse map takes a "repeat index" to distinguish between them.
|
||||
*
|
||||
* print() functions are included for printing dimensions and formatted forward and backwards
|
||||
* mappings similar to the AMD Matrix Calculator.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdio.h>
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// Utility to calculate register mappings from a Tile Distribution Encoding.
|
||||
template <typename TileDistrEnc>
|
||||
struct TileDistrEncRegMap
|
||||
{
|
||||
// Make sure this is a proper Tile Distr Encoding for Lane Vector mapping.
|
||||
static_assert(TileDistrEnc::NDimR <= 1);
|
||||
static_assert(TileDistrEnc::NDimX == 2);
|
||||
static_assert(TileDistrEnc::NDimP == 1);
|
||||
|
||||
static constexpr auto ps_ys_to_xs_adaptor =
|
||||
make_static_tile_distribution(TileDistrEnc{}).get_ps_ys_to_xs_adaptor();
|
||||
|
||||
static constexpr index_t mat_major_size =
|
||||
container_reduce(typename TileDistrEnc::HsLengthss{}[number<0>{}], multiplies<>{}, 1);
|
||||
static constexpr index_t mat_minor_size =
|
||||
container_reduce(typename TileDistrEnc::HsLengthss{}[number<1>{}], multiplies<>{}, 1);
|
||||
static constexpr index_t num_repeat = [] {
|
||||
if constexpr(TileDistrEnc::NDimR > 0)
|
||||
{
|
||||
return typename TileDistrEnc::RsLengths{}[number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // Necessary to deal with empty "repeat" sequences.
|
||||
}
|
||||
}();
|
||||
static constexpr index_t num_lanes = ps_ys_to_xs_adaptor.get_top_dimension_length(number<0>{});
|
||||
static constexpr index_t num_vector_items =
|
||||
container_reduce(TileDistrEnc::detail::ys_lengths_, multiplies<>{}, 1);
|
||||
|
||||
// Check for 0 dims (will break things much earlier but let's have an extra check).
|
||||
static_assert(mat_major_size > 0);
|
||||
static_assert(mat_minor_size > 0);
|
||||
static_assert(num_repeat > 0);
|
||||
static_assert(num_lanes > 0);
|
||||
static_assert(num_vector_items > 0);
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calc_matrix_indices_from_lane_vector(index_t lane_inx, index_t vector_inx)
|
||||
{
|
||||
// For some reason the Y dimension is not treated the same as the P dimension and we need to
|
||||
// manually unmerge the Y dimension index into its hidden indices before being able to use
|
||||
// it...
|
||||
array<index_t, TileDistrEnc::NDimY> y_hidden_inx;
|
||||
for(index_t i = TileDistrEnc::NDimY - 1; i >= 0; --i)
|
||||
{
|
||||
y_hidden_inx[i] = vector_inx % TileDistrEnc::detail::ys_lengths_[i];
|
||||
vector_inx /= TileDistrEnc::detail::ys_lengths_[i];
|
||||
}
|
||||
|
||||
const auto ps_ys_idx = container_concat(array<index_t, 1>{lane_inx}, y_hidden_inx);
|
||||
return ps_ys_to_xs_adaptor.calculate_bottom_index(ps_ys_idx);
|
||||
}
|
||||
|
||||
struct LaneVec
|
||||
{
|
||||
index_t lane = -1; // Sentinel for invalid pairs
|
||||
index_t vec = -1;
|
||||
};
|
||||
|
||||
using InverseMap =
|
||||
std::array<std::array<std::array<LaneVec, num_repeat>, mat_minor_size>, mat_major_size>;
|
||||
|
||||
// TODO: In theory this could be done with inverted merge unmerge operations.
|
||||
CK_TILE_HOST_DEVICE static constexpr InverseMap calc_inverse_map()
|
||||
{
|
||||
InverseMap im{};
|
||||
for(index_t l = 0; l < num_lanes; ++l)
|
||||
{
|
||||
for(index_t v = 0; v < num_vector_items; ++v)
|
||||
{
|
||||
auto res = calc_matrix_indices_from_lane_vector(l, v); // Matrix major, minor inx;
|
||||
|
||||
// We assume that repeated matrix elements appear at increasing L and V indices.
|
||||
for(index_t r = 0; r < num_repeat; r++)
|
||||
{
|
||||
auto& lv = im[res[0]][res[1]][r];
|
||||
if(lv.lane < 0)
|
||||
{
|
||||
lv.lane = l; // TODO: c++20 designated initializers
|
||||
lv.vec = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return im;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_dims()
|
||||
{
|
||||
printf("Matrix dims major, minor, repeat = %d %d %d\n",
|
||||
mat_major_size,
|
||||
mat_minor_size,
|
||||
num_repeat);
|
||||
printf("Num lanes, vector items = %d %d\n", num_lanes, num_vector_items);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_mapping()
|
||||
{
|
||||
printf("(lane, vector) item to matrix element\n L | ");
|
||||
for(index_t v = 0; v < num_vector_items; v++)
|
||||
{
|
||||
printf("vec%2d | ", v);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
for(index_t l = 0; l < num_lanes; l++)
|
||||
{
|
||||
printf("%2d | ", l);
|
||||
for(index_t v = 0; v < num_vector_items; v++)
|
||||
{
|
||||
auto res = calc_matrix_indices_from_lane_vector(l, v);
|
||||
printf("%2d %2d | ", res[0], res[1]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_inverse_mapping()
|
||||
{
|
||||
InverseMap im = calc_inverse_map();
|
||||
printf("Matrix element to (lane, vector item). Elements are replicated an additional %d "
|
||||
"time(s) in higher lanes. \n",
|
||||
num_repeat - 1);
|
||||
printf("Mat| ");
|
||||
for(index_t k = 0; k < mat_minor_size; k++)
|
||||
{
|
||||
printf(" %2d | ", k);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
for(index_t m = 0; m < mat_major_size; m++)
|
||||
{
|
||||
printf("%2d | ", m);
|
||||
for(index_t k = 0; k < mat_minor_size; k++)
|
||||
{
|
||||
printf("%2d %2d | ", im[m][k][0].lane, im[m][k][0].vec);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
print_dims();
|
||||
print_mapping();
|
||||
print_inverse_mapping();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum WmmaCtrlFlags
|
||||
* @brief Common wmma control flags for gfx11 and gfx12
|
||||
*/
|
||||
enum struct WmmaCtrlFlags : bool
|
||||
{
|
||||
// Only has an effect on gfx11 when the accumulator is 16-bit
|
||||
// Determines which half of the 32-bit accum register to use
|
||||
// Low = bits [15:0]
|
||||
// High = bits[31:16]
|
||||
LOW = false,
|
||||
HIGH = true,
|
||||
|
||||
// Only has an effect on gfx11 / 12 when the input is 8-bit int
|
||||
// Signage indicator of inputs / accum
|
||||
UNSIGNED = false,
|
||||
SIGNED = true
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the architecture-specific WMMA implementations and traits
|
||||
#include "wmma_gfx11.hpp"
|
||||
#include "wmma_gfx12.hpp"
|
||||
#include "wmma_selector.hpp"
|
||||
#include "wmma_traits.hpp"
|
||||
#include "wmma_transforms.hpp"
|
||||
112
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
112
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
|
||||
// - Duplicating A and B inputs
|
||||
// - Handling C / D is always in b32, even for f16 accumulation.
|
||||
// NOTE: Two suggestions:
|
||||
// 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to
|
||||
// duplicate the inputs as needed before calling the actual built-in. This may introduce
|
||||
// some instruction overhead and violate single responsibility clauses, but keeps the logic
|
||||
// contained within the backend wrapper.
|
||||
// 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing
|
||||
// pre-mma, mma and post-mma steps. The pre-mma step could handle input duplication transform
|
||||
// post-mma could implement D-shuffle transform. This may be cleaner and more flexible than
|
||||
// trying to handle everything in the backend wrappers.
|
||||
//
|
||||
// This current example assumes duplication has already been done, and that C data shuffles have
|
||||
// already been completed. (e.g. option 2 above). These expect duplicated inputs and pre-shuffled
|
||||
// data in C.
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things, such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @class DefaultWmmaFlags
|
||||
* @brief Generates default WMMA control flags based on data types.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename CDataType>
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
// Generate default flags for signage
|
||||
// Only used currently for integer inputs / accum in gfx11 / gfx12
|
||||
constexpr static WmmaCtrlFlags InputSignA =
|
||||
std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags InputSignB =
|
||||
std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags AccumSign =
|
||||
std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
|
||||
// Generate default flags for accumulator destination bits.
|
||||
// Only used if accumulation size is 16-bit in gfx11
|
||||
constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types (duplicated input / b32 accum)
|
||||
using AVecType = ext_vector_t<fp16_t, 16>;
|
||||
using BVecType = ext_vector_t<fp16_t, 16>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
72
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
72
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things, such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
173
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
173
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported WMMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct WmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension == 1, which is the base case for the recursive K dimension
|
||||
* search. If no supported instruction is found, falls back to an unsupported pass-through
|
||||
* implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
|
||||
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported WMMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common WMMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct WmmaOp
|
||||
* @brief Meta-tag for the WMMA operation. This will be used in the MmaOp struct to
|
||||
* identify the operation as an WMMA operation.
|
||||
*/
|
||||
struct WmmaOp;
|
||||
|
||||
/**
|
||||
* @class is_mma_op_wmma
|
||||
* @brief Trait to check if MmaOp is an WMMA operation
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_wmma : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_wmma
|
||||
* @brief MmaOp specialization for WMMA operations, confirming the OpType matches WmmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, WmmaOp>>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_wmma trait
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DuplicateTransform
|
||||
* @brief Transform to duplicate low register elements to high register elements
|
||||
*/
|
||||
struct DuplicateTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement duplication logic to broadcast low
|
||||
// register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct PadTransform
|
||||
* @brief Transform to pad data from original type to b32 type
|
||||
*/
|
||||
struct PadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 padding logic.
|
||||
// E.g., for fp16, pad each 16-bit element with 16 bits of 0 to make 32-bit elements
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct UnpadTransform
|
||||
* @brief Transform to unpad data from b32 type to original type
|
||||
*/
|
||||
struct UnpadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 logic to unpad 32 to original data type.
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx11
|
||||
* @brief Default MMA transforms for GFX11 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx11
|
||||
{
|
||||
using ATransform = DuplicateTransform;
|
||||
using BTransform = DuplicateTransform;
|
||||
using CTransform = PadTransform;
|
||||
using DTransform = UnpadTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx12
|
||||
* @brief Default MMA transforms for GFX12 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx12
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx11 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx11;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx12 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx12;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
144
include/ck_tile/core/arch/utility.hpp
Normal file
144
include/ck_tile/core/arch/utility.hpp
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: we have "memory" clobber here because this inline asm is used for async copy
|
||||
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
|
||||
}
|
||||
|
||||
// NOTE: this is an immediate value
|
||||
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_up(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_down(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
|
||||
{
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
|
||||
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
|
||||
|
||||
thread_buffer<T, 2> v;
|
||||
v(0) = bit_cast<T>(x[0]);
|
||||
v(1) = bit_cast<T>(x[1]);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
#if 0
|
||||
return __shfl(v_local, src_lane);
|
||||
#elif 1
|
||||
if constexpr(sizeof(int32_t) > sizeof(T))
|
||||
{
|
||||
union packet
|
||||
{
|
||||
int32_t x;
|
||||
T v;
|
||||
};
|
||||
packet p;
|
||||
p.v = v_local;
|
||||
packet p_remote;
|
||||
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
|
||||
|
||||
return p_remote.v;
|
||||
}
|
||||
else if constexpr(sizeof(int32_t) == sizeof(T))
|
||||
{
|
||||
const int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
|
||||
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
|
||||
using vector_type = thread_buffer<int32_t, elm>;
|
||||
auto vs = bit_cast<vector_type>(v_local);
|
||||
auto vs_remote = vector_type{};
|
||||
static_for<0, elm, 1>{}([&](auto i_e) {
|
||||
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
|
||||
vs_remote(i_e) = tmp;
|
||||
});
|
||||
return bit_cast<T>(vs_remote);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
// per-thread v_flag store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_flag] "v"(v_flag));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
|
||||
{
|
||||
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
|
||||
// per-thread cmp store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_x] "v"(x), [v_y] "v"(y));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
95
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
95
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct workgroup_barrier
|
||||
{
|
||||
CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
|
||||
|
||||
CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0)
|
||||
{
|
||||
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) != value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Reduces power consumption during polling by leveraging wave-level sleep instructions
|
||||
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
// Limit active polling to first wave to reduce memory traffic and power
|
||||
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
|
||||
if(threadIdx.x < wave_size)
|
||||
{
|
||||
uint32_t loaded_value = 0;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
|
||||
while(loaded_value != value)
|
||||
{
|
||||
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
|
||||
// busy-wait
|
||||
__builtin_amdgcn_s_sleep(1);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) < value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// enter critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); }
|
||||
|
||||
// exit critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); }
|
||||
|
||||
CK_TILE_DEVICE void inc(uint32_t offset = 0)
|
||||
{
|
||||
__syncthreads();
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(base_ptr + offset, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t* base_ptr;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
567
include/ck_tile/core/config.hpp
Normal file
567
include/ck_tile/core/config.hpp
Normal file
@@ -0,0 +1,567 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \
|
||||
defined(__gfx9_4_generic__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) || \
|
||||
defined(__gfx1013__) || defined(__gfx10_1_generic__)
|
||||
#define __gfx101__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
|
||||
#define __gfx115__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
#include "hip/hip_version.h"
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
// implementing the "memory address space" attribute
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
|
||||
// WA for https://github.com/ROCm/composable_kernel/issues/1946
|
||||
#if 0
|
||||
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
|
||||
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
|
||||
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
|
||||
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
|
||||
#else
|
||||
#define CK_TILE_GENERIC_ADDR
|
||||
#define CK_TILE_GLOBAL_ADDR
|
||||
#define CK_TILE_LDS_ADDR
|
||||
#define CK_TILE_BUF_RES_ADDR
|
||||
#endif
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
|
||||
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
|
||||
#endif
|
||||
|
||||
// in the old rocm period, we have to use tuple array implementation to implement this
|
||||
// so turn on the _USE_TUPLE if meet compiler error, otherwise _USE_ARRAY by default.
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
|
||||
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
|
||||
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
|
||||
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
|
||||
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
|
||||
// if using tuple-array as thread_buffer implementation, need to support {} brace init
|
||||
// ... with similiar behavior as array
|
||||
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
|
||||
#else
|
||||
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
|
||||
#define CK_TILE_USE_LAUNCH_BOUNDS 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_TIME_KERNEL
|
||||
#define CK_TILE_TIME_KERNEL 1
|
||||
#endif
|
||||
|
||||
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
|
||||
#define CK_TILE_MIN_BLOCK_PER_CU 2
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
|
||||
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
|
||||
#define CK_TILE_USE_AMD_BUFFER_STORE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// workaround for ROCm 6.2 and later
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// use llvm builtin bf16 data type after ROCm 6.5
|
||||
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
|
||||
(HIP_VERSION_MAJOR >= 7)
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
|
||||
#else
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
|
||||
defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx101__) || defined(__gfx103__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
// TODO: better solve this inside compiler
|
||||
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
// workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
|
||||
#define CK_TILE_WORKAROUND_SWDEV_383542 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_OCP_FP8
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ >= 20 && !(defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WA_ISSUE_2028
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
// Y pointed to R, we don't see a valuable use case.
|
||||
// Will enforce encoding to check Y not pointed to R if set to zero
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
|
||||
#endif
|
||||
|
||||
// Mark unsupported features with a deprecation warning in debug builds
|
||||
#if defined(NDEBUG)
|
||||
#define CK_TILE_UNSUPPORTED_IMPL(MSG)
|
||||
#else
|
||||
#define CK_TILE_UNSUPPORTED_IMPL(MSG) __attribute__((deprecated(MSG)))
|
||||
#endif
|
||||
|
||||
namespace ck_tile::core {
|
||||
/**
|
||||
* @struct amdgcn_compiler_target_state
|
||||
* @brief Defines compiler states for supported AMDGCN devices.
|
||||
* @var CK_TILE_HOST_COMPILE Indicates if the compilation is for the host.
|
||||
* @var CK_TILE_DEVICE_COMPILE Indicates if the compilation is for AMDGCN device.
|
||||
* @var CK_TILE_ARCH_GFX908 Indicates if the compiler target architecture is GFX908.
|
||||
* @var CK_TILE_ARCH_GFX90A Indicates if the compiler target architecture is GFX90A.
|
||||
* @var CK_TILE_ARCH_GFX942 Indicates if the compiler target architecture is GFX942.
|
||||
* @var CK_TILE_ARCH_GFX950 Indicates if the compiler target architecture is GFX950.
|
||||
* @var CK_TILE_ARCH_GFX1030 Indicates if the compiler target architecture is GFX1030.
|
||||
* @var CK_TILE_ARCH_GFX1031 Indicates if the compiler target architecture is GFX1031.
|
||||
* @var CK_TILE_ARCH_GFX1032 Indicates if the compiler target architecture is GFX1032.
|
||||
* @var CK_TILE_ARCH_GFX1034 Indicates if the compiler target architecture is GFX1034.
|
||||
* @var CK_TILE_ARCH_GFX1035 Indicates if the compiler target architecture is GFX1035.
|
||||
* @var CK_TILE_ARCH_GFX1036 Indicates if the compiler target architecture is GFX1036.
|
||||
* @var CK_TILE_ARCH_GFX10_3_GENERIC Indicates if the compiler target architecture is GFX10.3
|
||||
* generic.
|
||||
* @var CK_TILE_ARCH_GFX1100 Indicates if the compiler target architecture is GFX1100.
|
||||
* @var CK_TILE_ARCH_GFX1101 Indicates if the compiler target architecture is GFX1101.
|
||||
* @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102.
|
||||
* @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151.
|
||||
* @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152.
|
||||
* @var CK_TILE_ARCH_GFX1153 Indicates if the compiler target architecture is GFX1153.
|
||||
* @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic.
|
||||
* @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200.
|
||||
* @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201.
|
||||
* @var CK_TILE_ARCH_GFX12_GENERIC Indicates if the compiler target architecture is GFX12 generic.
|
||||
*/
|
||||
struct amdgcn_compiler_target_state
|
||||
{
|
||||
// Determine if we are compiling for device or host
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
|
||||
static constexpr bool CK_TILE_DEVICE_COMPILE = true;
|
||||
static constexpr bool CK_TILE_HOST_COMPILE = false;
|
||||
#else
|
||||
static constexpr bool CK_TILE_DEVICE_COMPILE = false;
|
||||
static constexpr bool CK_TILE_HOST_COMPILE = true;
|
||||
#endif // __HIP_DEVICE_COMPILE__ && __HIP_DEVICE_COMPILE__
|
||||
|
||||
// GFX9
|
||||
#if defined(__gfx908__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX908 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX908 = false;
|
||||
#endif // __gfx908__
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX90A = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX90A = false;
|
||||
#endif // __gfx90a__
|
||||
|
||||
#if defined(__gfx942__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX942 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX942 = false;
|
||||
#endif // __gfx942__
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX950 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX950 = false;
|
||||
#endif // __gfx950__
|
||||
|
||||
// GFX10
|
||||
#if defined(__gfx1010__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
|
||||
#endif
|
||||
#if defined(__gfx1011__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
|
||||
#endif
|
||||
#if defined(__gfx1012__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
|
||||
#endif
|
||||
#if defined(__gfx1013__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
|
||||
#endif
|
||||
#if defined(__gfx10_1_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
|
||||
#endif // __gfx10_1_generic__
|
||||
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = false;
|
||||
#endif // __gfx1030__
|
||||
|
||||
#if defined(__gfx1031__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1031 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1031 = false;
|
||||
#endif // __gfx1031__
|
||||
|
||||
#if defined(__gfx1032__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1032 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1032 = false;
|
||||
#endif // __gfx1032__
|
||||
|
||||
#if defined(__gfx1034__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1034 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1034 = false;
|
||||
#endif // __gfx1034__
|
||||
|
||||
#if defined(__gfx1035__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1035 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1035 = false;
|
||||
#endif // __gfx1035__
|
||||
|
||||
#if defined(__gfx1036__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1036 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1036 = false;
|
||||
#endif // __gfx1036__
|
||||
|
||||
#if defined(__gfx10_3_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = false;
|
||||
#endif // __gfx10_3_generic__
|
||||
|
||||
// GFX11
|
||||
#if defined(__gfx1100__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1100 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1100 = false;
|
||||
#endif // __gfx1100__
|
||||
|
||||
#if defined(__gfx1101__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1101 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1101 = false;
|
||||
#endif // __gfx1101__
|
||||
|
||||
#if defined(__gfx1102__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1102 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1102 = false;
|
||||
#endif // __gfx1102__
|
||||
|
||||
#if defined(__gfx1103__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1103 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1103 = false;
|
||||
#endif // __gfx1103__
|
||||
|
||||
#if defined(__gfx1150__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1150 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1150 = false;
|
||||
#endif // __gfx1150__
|
||||
|
||||
#if defined(__gfx1151__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1151 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1151 = false;
|
||||
#endif // __gfx1151__
|
||||
|
||||
#if defined(__gfx1152__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1152 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1152 = false;
|
||||
#endif // __gfx1152__
|
||||
|
||||
#if defined(__gfx1153__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1153 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1153 = false;
|
||||
#endif // __gfx1153__
|
||||
|
||||
#if defined(__gfx11_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = false;
|
||||
#endif // __gfx11_generic__
|
||||
|
||||
// GFX12
|
||||
#if defined(__gfx1200__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1200 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1200 = false;
|
||||
#endif // __gfx1200__
|
||||
|
||||
#if defined(__gfx1201__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1201 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1201 = false;
|
||||
#endif // __gfx1201__
|
||||
|
||||
#if defined(__gfx12_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = false;
|
||||
#endif // __gfx12_generic__
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper to count the number of times an item is contained within a list of values
|
||||
* @tparam T The type of the search value
|
||||
* @tparam Ts The types of the search list values
|
||||
* @param search The value to search for
|
||||
* @param searchList The list of values to search in
|
||||
* @return true if the search value is in the search list, false otherwise
|
||||
*/
|
||||
template <typename T, typename... Ts>
|
||||
// TODO: c++20 concept requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >=
|
||||
// 1))
|
||||
CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... searchList)
|
||||
{
|
||||
static_assert((std::is_convertible<Ts, T>::value && ...),
|
||||
"All search list values must be convertible to the search value type");
|
||||
static_assert(sizeof...(Ts) >= 1, "At least one value must be provided to search in");
|
||||
|
||||
return (static_cast<uint32_t>(search == static_cast<T>(searchList)) + ...);
|
||||
}
|
||||
|
||||
#define CK_TILE_COMPILER_TARGETS_LIST \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX908, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_3_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1100, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1101, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1102, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1103, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1153, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX12_GENERIC
|
||||
|
||||
// Sanity check: make sure only one target architecture is defined during device compile
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
|
||||
count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 1u,
|
||||
"Only one target architecture can be defined during device compile");
|
||||
|
||||
// Sanity check: make sure no device target architecture is defined during host compile
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
|
||||
count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 0u,
|
||||
"No device target architecture can be defined during host compile");
|
||||
|
||||
} // namespace ck_tile::core
|
||||
307
include/ck_tile/core/container/array.hpp
Normal file
307
include/ck_tile/core/container/array.hpp
Normal file
@@ -0,0 +1,307 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use aggregate initialization for this type
|
||||
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
|
||||
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
|
||||
// use make_array_with({...}) to construct an array with compatible behavior as old ck
|
||||
// TODO: manually added constructor same as old ck
|
||||
/**
|
||||
* @brief A fixed-size array container similar to std::array with additional utilities.
|
||||
*
|
||||
* This template class provides a lightweight fixed-size array with value semantics,
|
||||
* supporting both host and device functionality for GPU programming. It includes
|
||||
* specialized initialization methods and type punning capabilities.
|
||||
*
|
||||
* @tparam T_ The type of elements in the array
|
||||
* @tparam N_ The fixed number of elements in the array
|
||||
*
|
||||
* @note This implementation provides additional features beyond std::array:
|
||||
* - GPU compatibility via CK_TILE_HOST_DEVICE macros
|
||||
* - Type punning via get_as() and set_as() methods
|
||||
* - Various specialized access methods
|
||||
* - Specialized initialization behaviors
|
||||
*
|
||||
* The initializer_list constructor fills remaining elements with the last value
|
||||
* provided if the list size is smaller than N, which is different than std::array.
|
||||
*/
|
||||
template <typename T_, index_t N_>
|
||||
struct array
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
// TODO: do we need this?
|
||||
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
|
||||
// union {
|
||||
value_type data[N];
|
||||
// bulk_type __content;
|
||||
//};
|
||||
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
|
||||
// TODO: will initialize the data[] with the last value repeatedly
|
||||
// behavior different from std
|
||||
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
|
||||
{
|
||||
constexpr index_t list_size = std::initializer_list<value_type>{}.size();
|
||||
static_assert(list_size <= N, "out of bound");
|
||||
|
||||
index_t i = 0;
|
||||
value_type vlast = value_type{};
|
||||
|
||||
for(const value_type& val : ilist)
|
||||
{
|
||||
data[i] = val;
|
||||
vlast = val;
|
||||
++i;
|
||||
}
|
||||
for(; i < N; ++i)
|
||||
{
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y,
|
||||
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
|
||||
std::is_constructible_v<Y, value_type>>>
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
|
||||
{
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = static_cast<value_type>(c);
|
||||
}
|
||||
|
||||
// template <typename Y>
|
||||
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// }
|
||||
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// return *this;
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
|
||||
// clang-format off
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
|
||||
#if 0
|
||||
template <typename ArrayLike>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
|
||||
{
|
||||
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
|
||||
for(index_t i = 0; i < size(); ++i)
|
||||
{
|
||||
data[i] = arr[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
// type punning (strict aliasing) member functions for read/write
|
||||
// aliasing this array of type "T", "N" elements
|
||||
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
|
||||
#define AR_AS_COM_() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
|
||||
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
#undef AR_AS_COM_
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// empty Array
|
||||
|
||||
/// @brief Specialization of array container for zero elements.
|
||||
///
|
||||
/// This is a specialization of the array container template for the case where the number of
|
||||
/// elements is 0. It provides the same interface as the general array template, but with operations
|
||||
/// appropriate for an empty array.
|
||||
///
|
||||
/// @tparam T The type of elements stored in the array (not used in this specialization but
|
||||
/// maintained for API consistency).
|
||||
template <typename T>
|
||||
struct array<T, 0>
|
||||
{
|
||||
using value_type = T;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr array() {}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
|
||||
{
|
||||
printf("array{size: %ld, data: [", static_cast<long>(N));
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(a[i]);
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
|
||||
{
|
||||
printf("array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<array<T, N>, void>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
namespace details {
|
||||
template <class>
|
||||
struct is_ref_wrapper : std::false_type
|
||||
{
|
||||
};
|
||||
template <class T>
|
||||
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
|
||||
|
||||
template <class D, class...>
|
||||
struct return_type_helper
|
||||
{
|
||||
using type = D;
|
||||
};
|
||||
template <class... Ts>
|
||||
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
|
||||
{
|
||||
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
|
||||
"Ts cannot contain reference_wrappers when D is void");
|
||||
};
|
||||
|
||||
template <class D, class... Ts>
|
||||
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
|
||||
} // namespace details
|
||||
|
||||
template <typename D = void, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
|
||||
{
|
||||
return {std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
// // make empty array
|
||||
// template <typename T>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
// {
|
||||
// return array<T, 0>{};
|
||||
// }
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
|
||||
{
|
||||
return array<T, Size>(ilist);
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
for(index_t i = 0; i < Size; ++i)
|
||||
{
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
|
||||
{
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
|
||||
{
|
||||
static_assert(N <= X::size(), "");
|
||||
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
499
include/ck_tile/core/container/container_helper.hpp
Normal file
499
include/ck_tile/core/container/container_helper.hpp
Normal file
@@ -0,0 +1,499 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
|
||||
{
|
||||
array<TData, NSize + 1> r;
|
||||
static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
|
||||
r[number<NSize>{}] = x;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(make_tuple(x), a);
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(a, make_tuple(x));
|
||||
}
|
||||
|
||||
// reorder array
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
// reorder array
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array,
|
||||
const map<index_t, index_t>& new2old)
|
||||
{
|
||||
array<TData, NSize> new_array;
|
||||
|
||||
for(const auto& [new_pos, old_pos] : new2old)
|
||||
{
|
||||
new_array(new_pos) = old_array[old_pos];
|
||||
}
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_old2new(const array<TData, NSize>& old_array,
|
||||
const map<index_t, index_t>& old2new)
|
||||
{
|
||||
array<TData, NSize> new_array;
|
||||
|
||||
for(const auto& [old_pos, new_pos] : old2new)
|
||||
{
|
||||
new_array(new_pos) = old_array[old_pos];
|
||||
}
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
// reorder tuple
|
||||
template <typename... Ts, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
|
||||
sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return make_tuple(old_tuple[number<IRs>{}]...);
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
|
||||
sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
// reorder sequence
|
||||
template <index_t... Is, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
|
||||
sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
|
||||
sequence<IRs...> /* old2new */)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
|
||||
|
||||
return container_reorder_given_new2old(old_seq, new2old);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::size(),
|
||||
index_t IStep = 1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
number<IBegin> = number<0>{},
|
||||
number<IEnd> = number<Container::size()>{},
|
||||
number<IStep> = number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i + number<IStep>{}, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<IBegin>{}, init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename ROld,
|
||||
index_t I,
|
||||
index_t IEnd,
|
||||
index_t IStep>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
|
||||
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
}
|
||||
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
// container reduce with initial value
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::size(),
|
||||
index_t IStep = 1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
number<IBegin> = number<0>{},
|
||||
number<IEnd> = number<Container::size()>{},
|
||||
number<IStep> = number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
if constexpr(IEnd > IBegin)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return init;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
|
||||
{
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[number<0>{}]);
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
|
||||
{
|
||||
#if 0
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
y(i) = r;
|
||||
r = f(r, x[i]);
|
||||
});
|
||||
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
#else
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
for(index_t i = NSize - 1; i > 0; --i)
|
||||
{
|
||||
y(i) = r;
|
||||
r = f(r, x[i]);
|
||||
}
|
||||
|
||||
y(0) = r;
|
||||
|
||||
return y;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t... Is, typename Reduce, index_t Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
|
||||
{
|
||||
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
// rocm4.1 compiler would crash with recursive lambda
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i - number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
|
||||
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
return container_reverse_exclusive_scan_impl(
|
||||
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<>
|
||||
template <typename... Xs, typename Reduce, typename TData>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[number<0>{}]);
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
|
||||
{
|
||||
return container_concat(x, container_concat(ys...));
|
||||
}
|
||||
|
||||
template <typename T, index_t NX, index_t NY>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
return make_array<T>(arr[Is]...);
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<T, 0>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
return make_tuple(tup[number<Is>{}]...);
|
||||
}
|
||||
else
|
||||
{
|
||||
return tuple<>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
for(index_t i = 0; i < picks.size(); ++i)
|
||||
{
|
||||
y(picks[i]) = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
|
||||
{
|
||||
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
}
|
||||
|
||||
// return the index of first occurance in the sequence.
|
||||
// return seq.size(), if not found
|
||||
template <index_t... Is>
|
||||
constexpr index_t container_find(sequence<Is...> seq, index_t value)
|
||||
{
|
||||
for(auto i = 0; i < seq.size(); i++)
|
||||
{
|
||||
if(seq[i] == value)
|
||||
return i;
|
||||
}
|
||||
|
||||
return seq.size();
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
{
|
||||
using Seq = sequence<Is...>;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t tmp = Seq::at(i);
|
||||
return number<tmp>{};
|
||||
},
|
||||
number<Seq::size()>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, a_size, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
constexpr index_t b_size = bs_sizes[i]; \
|
||||
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
|
||||
return b; \
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#else
|
||||
// constexpr index_t can't be captured "-Wunused-lambda-capture"
|
||||
// TODO: this is ugly
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
constexpr index_t b_size = bs_sizes[i]; \
|
||||
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
|
||||
return b; \
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
163
include/ck_tile/core/container/map.hpp
Normal file
163
include/ck_tile/core/container/map.hpp
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// naive map
|
||||
template <typename key, typename data, index_t max_size = 128>
|
||||
struct map
|
||||
{
|
||||
using pair_type = tuple<key, data>;
|
||||
using impl_type = array<pair_type, max_size>;
|
||||
|
||||
impl_type impl_;
|
||||
index_t size_;
|
||||
|
||||
struct iterator
|
||||
{
|
||||
impl_type& impl_;
|
||||
index_t pos_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
|
||||
: impl_{impl}, pos_{pos}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
|
||||
{
|
||||
pos_++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
|
||||
{
|
||||
return other.pos_ != pos_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
|
||||
};
|
||||
|
||||
struct const_iterator
|
||||
{
|
||||
const impl_type& impl_;
|
||||
index_t pos_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
|
||||
: impl_{impl}, pos_{pos}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
|
||||
{
|
||||
pos_++;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
|
||||
{
|
||||
return other.pos_ != pos_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
|
||||
{
|
||||
for(index_t i = 0; i < size(); i++)
|
||||
{
|
||||
if(impl_[i].template at<0>() == k)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return size_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
|
||||
{
|
||||
return const_iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
|
||||
{
|
||||
return iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
|
||||
{
|
||||
const auto it = find(k);
|
||||
|
||||
// FIXME
|
||||
// assert(it.pos_ < size());
|
||||
|
||||
return impl_[it.pos_].template at<1>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
|
||||
{
|
||||
auto it = find(k);
|
||||
|
||||
// if entry not found
|
||||
if(it.pos_ == size())
|
||||
{
|
||||
impl_(it.pos_).template at<0>() = k;
|
||||
size_++;
|
||||
}
|
||||
|
||||
// FIXME
|
||||
// assert(size_ <= max_size);
|
||||
|
||||
return impl_(it.pos_).template at<1>();
|
||||
}
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
|
||||
{
|
||||
return const_iterator{impl_, size_};
|
||||
}
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
|
||||
};
|
||||
|
||||
template <typename key, typename data, index_t max_size>
|
||||
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
|
||||
{
|
||||
printf("map{size_: %d, impl_: [", m.size_);
|
||||
for(const auto& [k, d] : m)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: this structure is not intented to be used by user
|
||||
template <index_t MaxSize>
|
||||
struct meta_data_buffer
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
|
||||
: buffer_{}, size_{0}
|
||||
{
|
||||
push(x, xs...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
|
||||
{
|
||||
if constexpr(!std::is_empty_v<T>)
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
buffer_(size_) = tmp[i];
|
||||
|
||||
size_++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
|
||||
{
|
||||
push(x);
|
||||
push(xs...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
|
||||
{
|
||||
T data;
|
||||
|
||||
if constexpr(!std::is_empty_v<T>)
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
array<std::byte, size> tmp;
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
tmp(i) = buffer_[pos];
|
||||
|
||||
pos++;
|
||||
}
|
||||
|
||||
data = ck_tile::bit_cast<T>(tmp);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
array<std::byte, size> tmp;
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
tmp(i) = buffer_[pos];
|
||||
|
||||
pos++;
|
||||
}
|
||||
|
||||
auto data = ck_tile::bit_cast<T>(tmp);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
//
|
||||
array<std::byte, MaxSize> buffer_;
|
||||
index_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
100
include/ck_tile/core/container/multi_index.hpp
Normal file
100
include/ck_tile/core/container/multi_index.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Don't use tihs directly. This is for old CK's internal usage,
|
||||
// in the future always use array instead
|
||||
template <index_t N>
|
||||
using multi_index = array<index_t, N>;
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
|
||||
{
|
||||
return make_array<index_t>(index_t{xs}...);
|
||||
}
|
||||
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
|
||||
{
|
||||
return unpack([](auto... xs) { return make_multi_index(xs...); },
|
||||
typename uniform_sequence_gen<NSize, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
|
||||
{
|
||||
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// multi_index = index_t * multi_index
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
|
||||
{
|
||||
multi_index<NSize> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// multi_index = multi_index * index_t
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1378
include/ck_tile/core/container/sequence.hpp
Normal file
1378
include/ck_tile/core/container/sequence.hpp
Normal file
File diff suppressed because it is too large
Load Diff
78
include/ck_tile/core/container/span.hpp
Normal file
78
include/ck_tile/core/container/span.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <cstddef>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// implement the c++20 std::span, lightweight, non-owning reference to a sequence
|
||||
// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence
|
||||
// TODO: do we need in device consider this is pointer?
|
||||
template <typename T>
|
||||
class span
|
||||
{
|
||||
public:
|
||||
using element_type = T;
|
||||
using value_type = std::remove_cv_t<element_type>;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = element_type*;
|
||||
using const_pointer = const element_type*;
|
||||
using reference = element_type&;
|
||||
using const_reference = const element_type&;
|
||||
using iterator = pointer;
|
||||
using const_iterator = pointer;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
|
||||
|
||||
template <std::size_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
|
||||
{
|
||||
}
|
||||
|
||||
template <std::size_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
|
||||
: span(arr.data(), N)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
|
||||
: span(container.data(), container.size())
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
|
||||
CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
|
||||
{
|
||||
return *(begin() + idx);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }
|
||||
|
||||
private:
|
||||
pointer ptr_;
|
||||
size_type size_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
32
include/ck_tile/core/container/static_array.hpp
Normal file
32
include/ck_tile/core/container/static_array.hpp
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Fixed-size array with aggregate initialization for constexpr contexts.
|
||||
*
|
||||
* Unlike ck_tile::array, this has no custom constructors, making it a literal type
|
||||
* suitable for constexpr evaluation and GPU kernel code. Use ck_tile::array when
|
||||
* constructors or non-trivial initialization are needed.
|
||||
* Use aggregate initialization: static_array<int, 3> arr{1, 2, 3};
|
||||
*/
|
||||
template <typename T, index_t N>
|
||||
struct static_array
|
||||
{
|
||||
// Public aggregate initialization makes this a literal type.
|
||||
// N == 0 uses size 1 to avoid zero-length arrays (non-standard).
|
||||
T elems[N > 0 ? N : 1];
|
||||
|
||||
// Basic constexpr accessors
|
||||
CK_TILE_HOST_DEVICE constexpr const T& operator[](index_t i) const { return elems[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr T& operator[](index_t i) { return elems[i]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return N; }
|
||||
};
|
||||
} // namespace ck_tile
|
||||
41
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
41
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = tuple_array<T, N>;
|
||||
|
||||
#else
|
||||
|
||||
// consider mark this struct as deprecated
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = array<T, N>;
|
||||
|
||||
#endif
|
||||
|
||||
// consider always use ck_tile::array for this purpose
|
||||
#if 0
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
||||
}
|
||||
|
||||
// make empty statically_indexed_array
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return statically_indexed_array<X, 0>();
|
||||
}
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
176
include/ck_tile/core/container/thread_buffer.hpp
Normal file
176
include/ck_tile/core/container/thread_buffer.hpp
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = tuple_array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_tuple(ts...);
|
||||
}
|
||||
#else
|
||||
|
||||
#if 0
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_array(ts...);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template<typename T_, index_t N_>
|
||||
struct thread_buffer {
|
||||
using value_type = remove_cvref_t<T_>;
|
||||
static constexpr index_t N = N_;
|
||||
|
||||
value_type data[N];
|
||||
|
||||
// TODO: this ctor can't ignore
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
|
||||
static_for<0, N, 1>{}(
|
||||
[&](auto i) { data[i] = o; }
|
||||
);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE auto & get() {return data; }
|
||||
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
|
||||
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
template <typename X_,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
static_assert(N % kSPerX == 0);
|
||||
|
||||
union {
|
||||
thread_buffer<X_, N / kSPerX> data {};
|
||||
// tuple_array<value_type, kSPerX> sub_data;
|
||||
value_type sub_data[N];
|
||||
} vx;
|
||||
static_for<0, N, 1>{}(
|
||||
[&](auto j) { vx.sub_data[j] = data[j]; });
|
||||
return vx.data;
|
||||
}
|
||||
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data {};
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx;
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
|
||||
return vx.data;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data;
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx {x};
|
||||
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define TB_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
|
||||
#undef TB_COMMON_AS
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = typename T::type;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
872
include/ck_tile/core/container/tuple.hpp
Normal file
872
include/ck_tile/core/container/tuple.hpp
Normal file
@@ -0,0 +1,872 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <utility>
|
||||
#include <initializer_list>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
#ifndef CK_TILE_TUPLE_IMPL
|
||||
#define CK_TILE_TUPLE_IMPL 1
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl;
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
|
||||
|
||||
namespace impl {
|
||||
|
||||
// the place where content is stored
|
||||
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
|
||||
struct tuple_object
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, true>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, false>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
#endif
|
||||
T element;
|
||||
};
|
||||
|
||||
// NOTE: we return a instance(not a reference) if content is empty
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr const T&
|
||||
getv([[clang::lifetimebound]] const tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T& getv([[clang::lifetimebound]] tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
|
||||
{
|
||||
return static_cast<T&&>(x.element);
|
||||
}
|
||||
|
||||
template <typename index_seq, typename... T>
|
||||
struct tuple_base;
|
||||
|
||||
template <index_t... I, typename... T>
|
||||
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
|
||||
|
||||
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
#define _ILE() (std::initializer_list<U>{}.size() - 1)
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
|
||||
: tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
|
||||
{
|
||||
}
|
||||
#undef _ILE
|
||||
#endif
|
||||
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
|
||||
: tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <class U,
|
||||
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
|
||||
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
|
||||
"wrong! inconsistent size");
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <class... T>
|
||||
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr auto size() { return sizeof...(T); }
|
||||
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() = default;
|
||||
|
||||
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
|
||||
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <
|
||||
typename U,
|
||||
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U,
|
||||
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
|
||||
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
|
||||
});
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; }
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
|
||||
// below function should be used under tuple_array<> type, no extra check will perform here
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
|
||||
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
|
||||
|
||||
// clang-format on
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename... T>
|
||||
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
|
||||
{
|
||||
printf("tuple<");
|
||||
if constexpr(sizeof...(T) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
|
||||
if(!first)
|
||||
printf(", ");
|
||||
print(t.get(i));
|
||||
first = false;
|
||||
});
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename... T>
|
||||
struct vector_traits<tuple<T...>, void>
|
||||
{
|
||||
using scalar_type = __type_pack_element<0, T...>;
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
};
|
||||
|
||||
// template <class... T>
|
||||
// CK_TILE_HOST_DEVICE constexpr
|
||||
// tuple<T...>
|
||||
// make_tuple(T const&... t)
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
}
|
||||
});
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
// here xs is always a lvalue as function arg
|
||||
// Xs may deduced as (e.g try to pass in a integer in following cases)
|
||||
// 1). if pass in a rvalue (like function return or int{}) -> Xs is "int"
|
||||
// 2). if pass in a const lvalue -> Xs is "const int &"
|
||||
// 3). if pass in a non-const lvalue -> Xs is "int &"
|
||||
// so the return type of std::forward will dependes on Xs
|
||||
// 1). std::forward -> int&&
|
||||
// 2). std::forward -> const int&
|
||||
// 3). std::forward -> int&
|
||||
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/tie
|
||||
template <typename... Args>
|
||||
constexpr tuple<Args&...> tie(Args&... args) noexcept
|
||||
{
|
||||
return {args...};
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
|
||||
{
|
||||
using type = tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// be very careful using this type (because we want the internal type)
|
||||
// template deduction will fail if infering the inner type
|
||||
// e.g.
|
||||
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
|
||||
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
|
||||
// -> compiler will fail to deduce this type, because this is under non-deduced context
|
||||
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
|
||||
// contexts")
|
||||
//
|
||||
// -> use this instead
|
||||
// template<typename Tup> void foo(const Tup&) {}
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl
|
||||
{
|
||||
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
|
||||
typename tuple_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename F, index_t... ids>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence<ids...>)
|
||||
{
|
||||
return make_tuple(f(number<ids>{})...);
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
|
||||
{
|
||||
return generate_tuple_for(f, make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... is) { return tie(f(is)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
|
||||
const tuple<Y&...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
// Support any number of tuples to concat (also 1)
|
||||
template <typename... X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
|
||||
{
|
||||
return tx;
|
||||
}
|
||||
|
||||
template <typename... X, typename... Tuples>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
|
||||
{
|
||||
return concat_tuple(tx, concat_tuple(tuples...));
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple, index_t... Is>
|
||||
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple>
|
||||
constexpr decltype(auto) apply(F&& f, Tuple&& t)
|
||||
{
|
||||
constexpr index_t N = std::decay_t<Tuple>::size();
|
||||
return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return concat_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make sure F return at least a tuple
|
||||
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
|
||||
// this function will return
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::embed_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
|
||||
{
|
||||
return make_tuple(t);
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
|
||||
{
|
||||
if constexpr(Depth == MaxDepth)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
else
|
||||
{
|
||||
return unpack(
|
||||
[&](auto&&... ts) {
|
||||
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
|
||||
},
|
||||
t);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using Idx = number<tuple<Ts...>::size() - i - 1>;
|
||||
return t.at(Idx{});
|
||||
},
|
||||
number<tuple<Ts...>::size()>{});
|
||||
}
|
||||
|
||||
// Reduce tuple values in specific range using Function
|
||||
template <index_t Idx, index_t End, typename F, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
|
||||
{
|
||||
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
|
||||
if constexpr(Idx + 1 == End)
|
||||
{
|
||||
return t.at(number<Idx>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
|
||||
{
|
||||
return (is_detected<is_tuple, Ts>::value || ...);
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
|
||||
{
|
||||
return depth;
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
|
||||
{
|
||||
return max(tuple_depth<depth + 1>(Ts{})...);
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
|
||||
{
|
||||
constexpr index_t n0 = sizeof...(Seqs);
|
||||
|
||||
constexpr index_t max_n1 = [&] {
|
||||
index_t max_n1_ = 0;
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
|
||||
});
|
||||
|
||||
return max_n1_;
|
||||
}();
|
||||
|
||||
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
|
||||
});
|
||||
|
||||
return a_of_a;
|
||||
}
|
||||
|
||||
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
|
||||
// is the alias of the latter. This is because compiler cannot infer the NSize if
|
||||
// using MultiIndex<NSize>
|
||||
// TODO: how to fix this?
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = scalar * MultiIndex
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = MultiIndex * scalar
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include <tuple>
|
||||
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
||||
namespace std {
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
: std::tuple_element<I, const std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#if 1
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
|
||||
#else
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile:: \
|
||||
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
#pragma clang diagnostic pop
|
||||
443
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
443
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
@@ -0,0 +1,443 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class bf16_rounding_mode
|
||||
{
|
||||
standard = 0, // rtn
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr bfloat16_t bit_cast(raw_type x)
|
||||
{
|
||||
bfloat16_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
constexpr bfloat16_t() : data() {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const unsigned int& x)
|
||||
: data(float_to_bf16_raw(static_cast<float>(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<bfloat16_t>
|
||||
{
|
||||
using type = ushort;
|
||||
};
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
#endif
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
// round to nearest
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
{
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
if(~bits & 0x7f800000)
|
||||
{
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
|
||||
}
|
||||
else if(bits & 0xffff)
|
||||
{
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bloat16's mantissa bits are all 0.
|
||||
bits |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
return uint16_t(bits >> 16);
|
||||
}
|
||||
|
||||
CK_TILE_HOST
|
||||
constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
uint16_t float_to_bf16_rtn_asm(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
|
||||
static constexpr uint32_t FP32_NAN = 0x7fff0000;
|
||||
static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
|
||||
|
||||
#if defined(__GFX9__)
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
uint32x2_t check_nan;
|
||||
#else
|
||||
uint32_t check_nan;
|
||||
#endif
|
||||
uint32_t tmp;
|
||||
asm volatile("\n \
|
||||
v_cmp_u_f32 %0, %2, %2 \n \
|
||||
v_bfe_u32 %1, %2, 16, 1 \n \
|
||||
v_add3_u32 %1, %2, %1, %3 \n \
|
||||
v_cndmask_b32 %2, %1, %4, %0 \n \
|
||||
v_lshrrev_b32 %2, 16, %2 \n \
|
||||
"
|
||||
: "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
|
||||
: "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
|
||||
|
||||
return uint16_t(u.int32);
|
||||
}
|
||||
|
||||
// TODO: do we need this on host?
|
||||
CK_TILE_HOST
|
||||
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
uint16_t float_to_bf16_rta_asm(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
struct
|
||||
{
|
||||
uint16_t lo;
|
||||
uint16_t hi;
|
||||
};
|
||||
} u = {f};
|
||||
|
||||
const uint32_t low_nan = 0x7fff;
|
||||
const uint32_t hi_nan = 0x7fff0000;
|
||||
|
||||
#if defined(__GFX9__)
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
uint32x2_t check_nan;
|
||||
#else
|
||||
uint32_t check_nan;
|
||||
#endif
|
||||
|
||||
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
|
||||
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
|
||||
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
|
||||
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
|
||||
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
|
||||
|
||||
// Note: in above code snipet, we use hi 16 bit
|
||||
return u.hi;
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
{
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
|
||||
}
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_raw(float f)
|
||||
{
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
return static_cast<uint16_t>(bits >> 16);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
return float_to_bf16_rta_asm(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(x) << 16};
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double bf16_to_double_raw(uint16_t x)
|
||||
{
|
||||
return static_cast<double>(bf16_to_float_raw(x));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
// Use builtin bfloat16 conversion only on gfx950 as its predecessors do not support bf16 cvt
|
||||
// instructions, resulting in suboptimal performance; Add host side marcro check for consistency
|
||||
// during accuracy tests.
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16 && (defined(__gfx950__) || defined(CK_GFX950_SUPPORT))
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
|
||||
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<bfloat16_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
// maximum rounding error
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeeeeee mmmmmmm
|
||||
// 0 01111110 0000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<bfloat16_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
bfloat16_t abs(const bfloat16_t& x)
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const bfloat16_t& x)
|
||||
{
|
||||
uint16_t xx = bit_cast<bf16_raw_t>(x);
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t sqrt(bfloat16_t x)
|
||||
{
|
||||
return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t exp(bfloat16_t x)
|
||||
{
|
||||
return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x)
|
||||
{
|
||||
return bf16x2_t{float_to_bf16<rounding>(x.x), float_to_bf16<rounding>(x.y)};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
106
include/ck_tile/core/numeric/e8m0.hpp
Normal file
106
include/ck_tile/core/numeric/e8m0.hpp
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Unsigned representation of a conventional biased Float32 exponent.
|
||||
*
|
||||
* bias = 127;
|
||||
*
|
||||
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
|
||||
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
|
||||
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
|
||||
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
|
||||
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
|
||||
* E8M0_MIN = 0b00000000; => 2^-127
|
||||
* E8M0_MAX = 0b11111110; => 2^127
|
||||
* E8M0_NAN = 0b11111111; => NaN
|
||||
*/
|
||||
|
||||
struct e8m0_bexp_t
|
||||
{
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
|
||||
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
|
||||
|
||||
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
|
||||
};
|
||||
|
||||
using e8m0_t = e8m0_bexp_t;
|
||||
using e8m0_raw_t = typename e8m0_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_t>
|
||||
{
|
||||
using bitwise_type = e8m0_raw_t;
|
||||
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_t>
|
||||
{
|
||||
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
|
||||
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
|
||||
static constexpr e8m0_raw_t binary_nan = 0b11111111;
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
|
||||
{
|
||||
using traits = numeric_traits<float>;
|
||||
if(data == numeric<e8m0_t>::binary_nan)
|
||||
{
|
||||
return std::numeric_limits<float>::signaling_NaN();
|
||||
}
|
||||
else if(data == 0)
|
||||
{
|
||||
return std::numeric_limits<float>::min();
|
||||
}
|
||||
else
|
||||
{
|
||||
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
1120
include/ck_tile/core/numeric/float8.hpp
Normal file
1120
include/ck_tile/core/numeric/float8.hpp
Normal file
File diff suppressed because it is too large
Load Diff
410
include/ck_tile/core/numeric/half.hpp
Normal file
410
include/ck_tile/core/numeric/half.hpp
Normal file
@@ -0,0 +1,410 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp16_hip_t = _Float16; // most of hip internal function use this type
|
||||
using fp16_raw_t = uint16_t;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP use fp16_hip_t as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = fp16_raw_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr half_t bit_cast(raw_type x)
|
||||
{
|
||||
half_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// constructor
|
||||
constexpr half_t() : data{} {}
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const unsigned int& x)
|
||||
: half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const
|
||||
{
|
||||
return static_cast<int>(fp16_to_float_hip(to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<half_t>
|
||||
{
|
||||
using type = _Float16;
|
||||
};
|
||||
|
||||
using fp16_t = half_t;
|
||||
using fp16_raw_t = typename half_t::raw_type;
|
||||
#else
|
||||
using fp16_t = _Float16;
|
||||
using half_t = _Float16;
|
||||
using fp16_raw_t = ushort;
|
||||
#endif
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
|
||||
{
|
||||
return static_cast<double>(fp16_to_float_hip(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<half_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t max()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeee mmmmmmmmmm
|
||||
// 0 01110 0000000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t zero()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint16_t abs_mask = 0x7FFF;
|
||||
static constexpr uint16_t Inf = 0x7C00;
|
||||
static constexpr uint16_t NegInf = 0xFC00;
|
||||
static constexpr uint16_t NaN = 0x7C01;
|
||||
static constexpr uint16_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// arithmetic
|
||||
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
|
||||
{
|
||||
return __heq(x.to_fp16(), y.to_fp16());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
#if 0
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return bit_cast<half_t>(x.get() & 0x7fff); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const half_t& x)
|
||||
{
|
||||
uint16_t xx = x.get();
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t sqrt(half_t x)
|
||||
{
|
||||
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
|
||||
#endif
|
||||
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t vector_res;
|
||||
|
||||
vector_res.x = x.x + y.x;
|
||||
vector_res.y = x.y + y.y;
|
||||
|
||||
return vector_res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t c;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
|
||||
return c;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16x2_t fp32x2_to_fp16x2(const fp32x2_t& x)
|
||||
{
|
||||
return fp16x2_t{float_to_fp16(x.x), float_to_fp16(x.y)};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
103
include/ck_tile/core/numeric/int8.hpp
Normal file
103
include/ck_tile/core/numeric/int8.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use int8_t directly for int8 arithemetic
|
||||
// here one can use ck_tile::int8_t to access original int8_t
|
||||
using int8_t = int8_t;
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<int8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
|
||||
};
|
||||
|
||||
#if 0
|
||||
|
||||
template <>
|
||||
struct numeric_traits<int8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
|
||||
|
||||
} // namespace ck_tile
|
||||
14
include/ck_tile/core/numeric/integer.hpp
Normal file
14
include/ck_tile/core/numeric/integer.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using int32_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
98
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
98
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <auto v>
|
||||
struct constant
|
||||
{
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <auto v>
|
||||
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
|
||||
{
|
||||
printf("%ld", static_cast<long>(v));
|
||||
}
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
using number = constant<v>;
|
||||
|
||||
template <long_index_t v>
|
||||
using long_number = constant<v>;
|
||||
|
||||
template <bool b>
|
||||
using bool_constant = constant<b>;
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
#define CK_TILE_LEFT_UNARY_OP(OP) \
|
||||
template <auto x> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
|
||||
{ \
|
||||
return constant<(OP x)>{}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_BINARY_OP(OP) \
|
||||
template <auto x, auto y> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
|
||||
{ \
|
||||
return constant<(x OP y)>{}; \
|
||||
}
|
||||
|
||||
CK_TILE_LEFT_UNARY_OP(+)
|
||||
CK_TILE_LEFT_UNARY_OP(-)
|
||||
CK_TILE_LEFT_UNARY_OP(~)
|
||||
CK_TILE_LEFT_UNARY_OP(!)
|
||||
|
||||
CK_TILE_BINARY_OP(+)
|
||||
CK_TILE_BINARY_OP(-)
|
||||
CK_TILE_BINARY_OP(*)
|
||||
CK_TILE_BINARY_OP(/)
|
||||
CK_TILE_BINARY_OP(%)
|
||||
CK_TILE_BINARY_OP(&)
|
||||
CK_TILE_BINARY_OP(|)
|
||||
CK_TILE_BINARY_OP(^)
|
||||
CK_TILE_BINARY_OP(<<)
|
||||
CK_TILE_BINARY_OP(>>)
|
||||
CK_TILE_BINARY_OP(&&)
|
||||
CK_TILE_BINARY_OP(||)
|
||||
CK_TILE_BINARY_OP(==)
|
||||
CK_TILE_BINARY_OP(!=)
|
||||
CK_TILE_BINARY_OP(>)
|
||||
CK_TILE_BINARY_OP(<)
|
||||
CK_TILE_BINARY_OP(>=)
|
||||
CK_TILE_BINARY_OP(<=)
|
||||
|
||||
#undef CK_TILE_LEFT_UNARY_OP
|
||||
#undef CK_TILE_BINARY_OP
|
||||
|
||||
template <typename T>
|
||||
struct is_constant : std::false_type
|
||||
{
|
||||
};
|
||||
template <auto v>
|
||||
struct is_constant<constant<v>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_constant_v = is_constant<T>::value;
|
||||
} // namespace ck_tile
|
||||
1472
include/ck_tile/core/numeric/math.hpp
Normal file
1472
include/ck_tile/core/numeric/math.hpp
Normal file
File diff suppressed because it is too large
Load Diff
218
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
218
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
// modify from include/ck/utility/mxfp_utils.hpp
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils : numeric_traits<T>
|
||||
{
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename traits::bitwise_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr raw_type get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr raw_type get_exponent(const T& x)
|
||||
{
|
||||
return get_exponent(bit_cast<raw_type>(x));
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
}
|
||||
static constexpr bool is_subnormal(raw_type x)
|
||||
{
|
||||
return get_exponent(x) == _numeric::binary_zero;
|
||||
}
|
||||
// TODO: replace double with template arg?
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(raw_type i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
}
|
||||
return mantissa;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant * scale, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
value /= scale;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
bitwise_type sign =
|
||||
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
|
||||
bitwise_type exp =
|
||||
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
|
||||
(numeric_traits<float>::bias - numeric_traits<T>::bias);
|
||||
bitwise_type mantissa =
|
||||
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
|
||||
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
|
||||
mant_prev--;
|
||||
|
||||
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
uint32_t prev_bit =
|
||||
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
|
||||
mant_prev;
|
||||
|
||||
float prev_val = bit_cast<float>(prev_bit);
|
||||
float diff = max_value - prev_val;
|
||||
|
||||
float actual_max = max_value + (diff / 2);
|
||||
|
||||
if(std::abs(value) < actual_max)
|
||||
{
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!numeric<T>::has_inf())
|
||||
{
|
||||
|
||||
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
exp++;
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int mfmt = numeric_traits<float>::mant;
|
||||
uint32_t x;
|
||||
x = bit_cast<uint32_t>(value);
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int32_t exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
head = x & numeric_traits<float>::head_mask;
|
||||
mantissa = x & numeric_traits<float>::mant_mask;
|
||||
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
|
||||
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
|
||||
bias = numeric_traits<float>::bias;
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0b0;
|
||||
}
|
||||
|
||||
const int mini_bias = numeric_traits<T>::bias;
|
||||
const int mini_denormal_act_exponent = 1 - mini_bias;
|
||||
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
bool is_subnorm = false;
|
||||
|
||||
if(exponent == 0)
|
||||
{
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= mini_denormal_act_exponent)
|
||||
{
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
exponent_diff = 0;
|
||||
}
|
||||
mantissa += (1UL << mfmt);
|
||||
}
|
||||
|
||||
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
|
||||
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
|
||||
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
|
||||
|
||||
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
|
||||
|
||||
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
|
||||
{
|
||||
// closer to 0
|
||||
if(std::abs(value) <= std::abs(min_subnorm - value))
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
else
|
||||
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
|
||||
}
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
|
||||
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
|
||||
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
|
||||
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1UL << mfmt) & mantissa)
|
||||
{
|
||||
out_exponent = 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1UL << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - numeric_traits<T>::mant);
|
||||
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
{
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
}
|
||||
|
||||
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
|
||||
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(out_exponent << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
13
include/ck_tile/core/numeric/null_type.hpp
Normal file
13
include/ck_tile/core/numeric/null_type.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct null_type
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
194
include/ck_tile/core/numeric/numeric.hpp
Normal file
194
include/ck_tile/core/numeric/numeric.hpp
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this struct has the information of
|
||||
// 1. limit of a certain type, simliar to std::numeric_limits
|
||||
// 2. some pre-defined value, zero, one...
|
||||
//
|
||||
template <typename T>
|
||||
struct numeric
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr T round_error()
|
||||
{
|
||||
return std::numeric_limits<T>::round_error();
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::signaling_NaN();
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
|
||||
{
|
||||
return std::numeric_limits<T>::denorm_min();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }
|
||||
|
||||
#ifndef C_LOG2E
|
||||
#define C_LOG2E 1.44269504088896340736 // log2(e)
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T log2e()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
|
||||
{
|
||||
return static_cast<T>(C_LOG2E);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // TODO: integer?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_traits
|
||||
{
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
|
||||
static_cast<float>(numeric<type_>::epsilon()); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
}
|
||||
523
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
523
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
@@ -0,0 +1,523 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define CK_TILE_FP4_CVT_DEVICE 1
|
||||
#else
|
||||
#define CK_TILE_FP4_CVT_DEVICE 0
|
||||
#endif
|
||||
|
||||
#define TEST_convert_with_table 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
|
||||
#else
|
||||
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
|
||||
#endif
|
||||
|
||||
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
|
||||
// Some compilers don't allow operator[] in constant expressions for vector types.
|
||||
// We use bit_cast to a trivially copyable representation to extract lanes.
|
||||
namespace detail {
|
||||
struct fp16x2_repr
|
||||
{
|
||||
_Float16 e[2];
|
||||
};
|
||||
struct bf16x2_repr
|
||||
{
|
||||
bfloat16_t e[2];
|
||||
};
|
||||
struct fp32x2_repr
|
||||
{
|
||||
float e[2];
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane0(const bf16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane1(const bf16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float lane0(const fp32x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float lane1(const fp32x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
struct pk_float4_e2m1_t;
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
|
||||
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
|
||||
: data{float_to_pk_fp4(init, scale)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr type get() const { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const
|
||||
{
|
||||
return _unpack(number<I>{});
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0,
|
||||
const pk_float4_e2m1_t& x1)
|
||||
{
|
||||
return _pack(x0.get(), x1.get());
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr type _unpack(number<I>) const;
|
||||
CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1)
|
||||
{
|
||||
return (x1 << 4) | (x0 & 0b00001111);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table
|
||||
static constexpr float e2m1_to_fp32_table[16] = {
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
|
||||
static constexpr fp16_t e2m1_to_fp16_table[16] = {
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
// FP8 EM4E3 (OCP) representation
|
||||
static constexpr fp8_t e2m1_to_fp8_table[16] = {
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // 0
|
||||
fp8_t(static_cast<uint8_t>(0x30)), // 0.5
|
||||
fp8_t(static_cast<uint8_t>(0x38)), // 1
|
||||
fp8_t(static_cast<uint8_t>(0x3C)), // 1.5
|
||||
fp8_t(static_cast<uint8_t>(0x40)), // 2
|
||||
fp8_t(static_cast<uint8_t>(0x44)), // 3
|
||||
fp8_t(static_cast<uint8_t>(0x48)), // 4
|
||||
fp8_t(static_cast<uint8_t>(0x4C)), // 6
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // -0
|
||||
fp8_t(static_cast<uint8_t>(0xB0)), // -0.5
|
||||
fp8_t(static_cast<uint8_t>(0xB8)), // -1
|
||||
fp8_t(static_cast<uint8_t>(0xBC)), // -1.5
|
||||
fp8_t(static_cast<uint8_t>(0xC0)), // -2
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -3
|
||||
fp8_t(static_cast<uint8_t>(0xC8)), // -4
|
||||
fp8_t(static_cast<uint8_t>(0xCC)) // -6
|
||||
};
|
||||
#else // CK_TILE_USE_FNUZ_FP8
|
||||
// FP8 E4M3 FNUZ
|
||||
static constexpr fp8_t e2m1_to_fp8_table[16] = {
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // 0
|
||||
fp8_t(static_cast<uint8_t>(0x38)), // 0.5
|
||||
fp8_t(static_cast<uint8_t>(0x40)), // 1
|
||||
fp8_t(static_cast<uint8_t>(0x44)), // 1.5
|
||||
fp8_t(static_cast<uint8_t>(0x48)), // 2
|
||||
fp8_t(static_cast<uint8_t>(0x4C)), // 3
|
||||
fp8_t(static_cast<uint8_t>(0x50)), // 4
|
||||
fp8_t(static_cast<uint8_t>(0x54)), // 6
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // -0
|
||||
fp8_t(static_cast<uint8_t>(0xB8)), // -0.5
|
||||
fp8_t(static_cast<uint8_t>(0xC0)), // -1
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -1.5
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -2
|
||||
fp8_t(static_cast<uint8_t>(0xCC)), // -3
|
||||
fp8_t(static_cast<uint8_t>(0xD0)), // -4
|
||||
fp8_t(static_cast<uint8_t>(0xD4)) // -6
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
using pk_fp4_t = pk_float4_e2m1_t;
|
||||
using pk_fp4_raw_t = typename pk_fp4_t::type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_fp4_t>
|
||||
{
|
||||
using bitwise_type = pk_fp4_raw_t;
|
||||
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 1;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_fp4_t>
|
||||
{
|
||||
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
|
||||
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
|
||||
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
|
||||
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number<I>) const
|
||||
{
|
||||
static_assert(I < 2, "Index is out of range.");
|
||||
if constexpr(I == 1)
|
||||
return (data >> 4);
|
||||
else
|
||||
return data & 0b00001111;
|
||||
}
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
|
||||
// TODO: consider replace this macro to improve performance
|
||||
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
{
|
||||
fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
{
|
||||
fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
{
|
||||
bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return T{};
|
||||
}
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
{
|
||||
union
|
||||
{
|
||||
uint32_t u32;
|
||||
pk_fp4_raw_t pf4[4];
|
||||
} cvt{0};
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
|
||||
cvt.u32, detail::lane0(src), detail::lane1(src), scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return cvt.pf4[0];
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data, scale);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data, scale);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make it generic so that we can convert from directrly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
auto res = convert_to_type<pk_fp4_t>(x, scale);
|
||||
return pk_fp4_t::_pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
auto res = float_to_mxfp4(type_convert<float>(x), scale);
|
||||
return pk_fp4_t::_pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
auto res = float_to_mxfp4(type_convert<float>(x), scale);
|
||||
return pk_fp4_t::_pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data, scale);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data, scale);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale),
|
||||
convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale)};
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data, scale);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data, scale);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
|
||||
{
|
||||
// NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
|
||||
// would be better than the naive implementation below
|
||||
// #if CK_TILE_FP4_CVT_DEVICE
|
||||
// return impl::_from_f4<fp8_t>(data, scale);
|
||||
// #else
|
||||
return fp8_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
|
||||
// #endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
|
||||
{
|
||||
// NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
|
||||
// would be better than the naive implementation below
|
||||
// #if CK_TILE_FP4_CVT_DEVICE
|
||||
// return impl::_from_f4<fp8x2_t>(data, scale);
|
||||
// #else
|
||||
return fp8x2_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
|
||||
type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
|
||||
// #endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
return e2m1_to_fp32_table[_unpack(number<0>{})] * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale,
|
||||
e2m1_to_fp32_table[_unpack(number<1>{})] * scale};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
return type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
return fp16x2_t{
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale),
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
|
||||
scale)};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
|
||||
{
|
||||
return type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
|
||||
{
|
||||
return fp8x2_t{
|
||||
type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale),
|
||||
type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
109
include/ck_tile/core/numeric/pk_fp6.hpp
Normal file
109
include/ck_tile/core/numeric/pk_fp6.hpp
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <index_t pk_size>
|
||||
struct pk_fp6_t
|
||||
{
|
||||
static constexpr index_t num_bits_elem = 6;
|
||||
using element_type = int32_t; // element storage fundamental type
|
||||
static constexpr index_t packed_size = pk_size;
|
||||
static constexpr index_t num_bits_vec_elem =
|
||||
sizeof(element_type) * 8; // 32-bit uint for storage
|
||||
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
|
||||
"Packed elements must fit exactly into the element storage.");
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
element_type data_[vector_size]; // packed data
|
||||
using type = pk_fp6_t<packed_size>;
|
||||
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0)
|
||||
{
|
||||
for(size_t i = 0; i < vector_size; ++i)
|
||||
{
|
||||
data_[i] = value;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE void pack(const int32_t x, const index_t i)
|
||||
{
|
||||
int32_t bits = static_cast<int32_t>(x) & 0x3F;
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_index = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
int32_t old_value = data_[arr_index];
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
data_[arr_index] = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
int32_t next_value = data_[arr_index + 1];
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
data_[arr_index + 1] = next_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static int32_t unpack(const T& pk, const index_t i)
|
||||
{
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
int32_t bits = pk.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return bits & 0x3F;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); }
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static float fp6_e2m3_to_float(int32_t fp6_bits)
|
||||
{
|
||||
fp6_bits = fp6_bits & 0x3F;
|
||||
|
||||
uint32_t sign = (fp6_bits >> 5) & 0x1; // bit 5
|
||||
uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3
|
||||
uint32_t mantissa = fp6_bits & 0x7; // bits 2-0
|
||||
|
||||
float result;
|
||||
if(exponent == 0 && mantissa == 0)
|
||||
{
|
||||
result = 0.f;
|
||||
}
|
||||
else if(exponent != 0)
|
||||
{
|
||||
result = std::exp2f(static_cast<int>(exponent) - 1);
|
||||
float mantissa_value = 1.0f + mantissa / 8.0f;
|
||||
result *= mantissa_value;
|
||||
}
|
||||
else
|
||||
{
|
||||
result = mantissa / 8.0f;
|
||||
}
|
||||
return sign == 1 ? -1 * result : result;
|
||||
}
|
||||
};
|
||||
|
||||
using pk_fp6x16_t = pk_fp6_t<16>;
|
||||
using pk_fp6x32_t = pk_fp6_t<32>;
|
||||
template <>
|
||||
struct numeric_traits<pk_fp6x16_t>
|
||||
{
|
||||
static constexpr int PackedSize = 16;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
200
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
200
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
@@ -0,0 +1,200 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Packed 2xint4
|
||||
struct pk_int4_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
|
||||
|
||||
// NOTE: added for interface compatibility with pk_fp4_t
|
||||
// Other data types could be added for greater similarity
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_int4_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
|
||||
{
|
||||
constexpr uint8_t val = 0b01110111;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_int4_t>
|
||||
{
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
|
||||
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
fp32x2_t res = {x_h, x_l};
|
||||
#else
|
||||
fp32x2_t res = {x_l, x_h};
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
float x_l = ((x_u8 & 0x0f) >> 0);
|
||||
float x_h = ((x_u8 & 0xf0) >> 4);
|
||||
|
||||
x_l = x_l > 7 ? x_l - 16 : x_l;
|
||||
x_h = x_h > 7 ? x_h - 16 : x_h;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
fp32x2_t res = {x_h, x_l};
|
||||
#else
|
||||
fp32x2_t res = {x_l, x_h};
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
|
||||
#else
|
||||
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
|
||||
#endif
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0xE408E408; //-8
|
||||
|
||||
int lo = i4s | EX;
|
||||
|
||||
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
|
||||
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
|
||||
#else
|
||||
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
int8_t x_l = (x_u8 & 0x0F);
|
||||
int8_t x_h = (x_u8 & 0xF0) >> 4;
|
||||
|
||||
if(x_l & 0x08)
|
||||
x_l |= 0xF0;
|
||||
if(x_h & 0x08)
|
||||
x_h |= 0xF0;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
int8x2_t res = {x_h, x_l};
|
||||
#else
|
||||
int8x2_t res = {x_l, x_h};
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const
|
||||
{
|
||||
return pk_int4_t_to_fp32x2_t(*this);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
111
include/ck_tile/core/numeric/type_convert.hpp
Normal file
111
include/ck_tile/core/numeric/type_convert.hpp
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
|
||||
{
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
#else
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using non_const_y = std::remove_const_t<Y>;
|
||||
using non_const_x = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
|
||||
|
||||
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
|
||||
float scale) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, scale); \
|
||||
} \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, 1.f); \
|
||||
}
|
||||
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
|
||||
#undef CK_TILE_SCALED_TYPE_CONVERT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
293
include/ck_tile/core/numeric/vector_type.hpp
Normal file
293
include/ck_tile/core/numeric/vector_type.hpp
Normal file
@@ -0,0 +1,293 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this structure is used to pick up the <base> type inside
|
||||
// using xxx = <base> __attribute__((ext_vector_type(N)));
|
||||
// because clang only allow native type + bool in this term (custom type will fail)
|
||||
// overload this structure to let proper <base> type
|
||||
|
||||
template <typename T>
|
||||
struct native_t
|
||||
{
|
||||
using type = remove_cvref_t<T>;
|
||||
};
|
||||
|
||||
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
|
||||
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
|
||||
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
|
||||
// have compiler error
|
||||
namespace impl {
|
||||
|
||||
template <typename T_, index_t N_, typename = void>
|
||||
struct ext_vector;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <typename T, index_t N>
|
||||
using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T, typename = void>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
|
||||
int8_t,
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, e8m0_t>,
|
||||
uint8_t,
|
||||
remove_cvref_t<T>>>;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
|
||||
{
|
||||
using scalar_type = std::conditional_t<
|
||||
std::is_same_v<T, pk_int4_t>,
|
||||
int8_t,
|
||||
std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
|
||||
uint8_t,
|
||||
T>>;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
|
||||
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// attention! 2 vector type could be just the same type
|
||||
// fp64
|
||||
using fp64_t = double;
|
||||
using fp64x2_t = double __attribute__((ext_vector_type(2)));
|
||||
using fp64x4_t = double __attribute__((ext_vector_type(4)));
|
||||
|
||||
// fp32
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp32x4_t = float __attribute__((ext_vector_type(4)));
|
||||
using fp32x8_t = float __attribute__((ext_vector_type(8)));
|
||||
using fp32x16_t = float __attribute__((ext_vector_type(16)));
|
||||
using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
// using fp16_t = ...
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
using int32x3_t = int32_t __attribute__((ext_vector_type(3)));
|
||||
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
|
||||
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
|
||||
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
struct int32x3_tt
|
||||
{
|
||||
int32_t data[3];
|
||||
};
|
||||
|
||||
struct int32x6_tt
|
||||
{
|
||||
int32_t data[6];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<int8_t, 12>
|
||||
{
|
||||
static constexpr index_t N = 12;
|
||||
using value_type = int32x3_tt;
|
||||
using type = int32x3_tt;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<pk_fp6x16_t, 1>
|
||||
{
|
||||
static constexpr index_t N = 1;
|
||||
using value_type = int32x3_tt;
|
||||
using type = int32x3_tt;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<pk_fp6x16_t, 2>
|
||||
{
|
||||
static constexpr index_t N = 2;
|
||||
using value_type = int32x6_tt;
|
||||
using type = int32x6_tt;
|
||||
};
|
||||
|
||||
// u32
|
||||
// using uint32_t = ...
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
|
||||
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
|
||||
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
|
||||
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
|
||||
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
// using int16_t = ...
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
|
||||
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
|
||||
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
// using uint16_t
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
|
||||
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
// using int8_t
|
||||
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute__((ext_vector_type(4)));
|
||||
using int8x8_t = int8_t __attribute__((ext_vector_type(8)));
|
||||
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// ui8
|
||||
// using uint8_t
|
||||
using uint8x2_t = uint8_t __attribute__((ext_vector_type(2)));
|
||||
using uint8x4_t = uint8_t __attribute__((ext_vector_type(4)));
|
||||
using uint8x8_t = uint8_t __attribute__((ext_vector_type(8)));
|
||||
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
|
||||
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
|
||||
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
|
||||
#else
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
// pk_int4_t
|
||||
// using pk_int4_t
|
||||
using pk_int4x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
|
||||
using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8)));
|
||||
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
|
||||
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
|
||||
|
||||
using pk_fp4x2_t = uint8_t __attribute((ext_vector_type(2)));
|
||||
using pk_fp4x4_t = uint8_t __attribute((ext_vector_type(4)));
|
||||
using pk_fp4x8_t = uint8_t __attribute((ext_vector_type(8)));
|
||||
using pk_fp4x16_t = uint8_t __attribute((ext_vector_type(16)));
|
||||
using pk_fp4x32_t = uint8_t __attribute((ext_vector_type(32)));
|
||||
} // namespace ck_tile
|
||||
1327
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
1327
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
File diff suppressed because it is too large
Load Diff
212
include/ck_tile/core/tensor/load_tile.hpp
Normal file
212
include/ck_tile/core/tensor/load_tile.hpp
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load_with_offset(
|
||||
offset, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note This function is a modification of the existing load function.
|
||||
* It has been extended with two additional parameters: it takes a tuple as input
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should work with unknown number of params
|
||||
// Load element_wise API works only when the input type is a tuple-type
|
||||
return tile_windows[number<0>{}].load(
|
||||
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load_with_offset(
|
||||
offset, dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Loads a tile of data using inline assembly.
|
||||
*
|
||||
* @note Bare in mind that loading data this way, you have to manually initialize your
|
||||
* thread buffer and synchronize load afterwards in order to make sure it's done before
|
||||
* using loaded data from registers
|
||||
* @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
|
||||
* @see `buffer_load_fence()`
|
||||
*/
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
tile_window.async_load_with_offset(offset, lds_tile, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false>
|
||||
CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
async_load_tile_with_offset(lds_tile, tile_window, 0, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
template <typename T, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
532
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
532
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
@@ -0,0 +1,532 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr int DS_READ_TR_SIZE()
|
||||
{
|
||||
return 8; // Literal constant, evaluated at compile time
|
||||
}
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
{
|
||||
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
|
||||
|
||||
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
|
||||
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
|
||||
|
||||
static constexpr bool value =
|
||||
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
|
||||
};
|
||||
|
||||
template <index_t... Xs>
|
||||
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Suffix, typename Sequence>
|
||||
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
|
||||
|
||||
} // namespace util
|
||||
|
||||
// Default policy: Retains original 2D transpose behavior
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
template <index_t LaneGroupSize, index_t NumBitType>
|
||||
struct Quad
|
||||
{
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
|
||||
// The tile is defined by the LaneGroupSize, which defines the number of lanes in the M/N
|
||||
// dimensions for the MMA instruction defined by warp gemm.
|
||||
// The LaneGroupSize is subdivided into groups of 16 (finer granularity of MMA
|
||||
// instructions), we define these as major subtiles. Each of these major subtile is divided
|
||||
// into minor subtiles which group the lanes exchanging data during the transpose Example
|
||||
// LaneGroupSize = 16, 16 bit type:
|
||||
// - There is 1 group of 16 lanes (1 major subtile)
|
||||
// - Each major subtile is divided into 4 minor subtiles of (4x4) -> 4 lanes transpose
|
||||
// the minor subtile and each lane holds 4 elements
|
||||
|
||||
// all load transpose instructions use 64 bit right now
|
||||
static constexpr index_t InstructionBits = 64;
|
||||
// Subtile major dimension is fixed
|
||||
static constexpr index_t SubtileMajorDimension = 16;
|
||||
// Number of subtile major
|
||||
static constexpr index_t NumSubtilesMajor = LaneGroupSize / 16;
|
||||
// number of elements loaded by each lane with single instruction, but also number
|
||||
// of consecutive lanes in a subtile. Subtile is squared (NLanes x NElementsPerLane)
|
||||
static constexpr index_t SubtileMinorDimension = InstructionBits / NumBitType;
|
||||
// Number of subtiles minor inside each subtile major
|
||||
static constexpr index_t NumSubtilesMinor = 16 / SubtileMinorDimension;
|
||||
|
||||
using InputEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<SubtileMinorDimension>,
|
||||
sequence<NumSubtilesMajor, NumSubtilesMinor, SubtileMinorDimension>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<SubtileMinorDimension>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
static constexpr index_t PackedSize = numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t NumBitsDataType = (sizeof(DataType) * 8) / PackedSize;
|
||||
|
||||
// Select based on data size
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadInputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::InputEncoding;
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadOutputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::OutputEncoding;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
// Programmable: Element grouping function
|
||||
static constexpr auto group_func = [](auto idx) {
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
|
||||
|
||||
static constexpr auto input_ps_major_last =
|
||||
input_ps_major[number<input_ps_major.size() - 1>{}];
|
||||
static constexpr auto input_ps_minor_last =
|
||||
input_ps_minor[number<input_ps_minor.size() - 1>{}];
|
||||
|
||||
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
|
||||
input_hs[I1].size() - quad_hs[I1].size()>;
|
||||
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
|
||||
},
|
||||
number<quad_ps_minor0.size()>{});
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
|
||||
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
|
||||
decltype(input_ps_minor_last)>;
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
|
||||
"YS->RHS mapping must be single dimension");
|
||||
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
|
||||
"YS->RHS mapping must be the last dimension");
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection = false>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr bool value =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
|
||||
: 0;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
|
||||
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
|
||||
|
||||
static constexpr bool distr_encoding_valid = Validator::value;
|
||||
};
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>,
|
||||
bool ReverseDirection = false>
|
||||
struct TransposeTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
|
||||
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
|
||||
"The input tile distribution encoding is not valid for transpose!");
|
||||
|
||||
using QuadInputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
|
||||
using QuadOutputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
|
||||
|
||||
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr index_t dim0 = Policy::transpose_dims[0];
|
||||
static constexpr index_t dim1 = Policy::transpose_dims[1];
|
||||
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
|
||||
// for transpose load
|
||||
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
|
||||
// dims and append the quad_output_hs_lengthss to the end of each dimension
|
||||
static constexpr auto outer_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_hs_lengthss[i];
|
||||
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
|
||||
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
|
||||
static constexpr auto dst_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
auto outer_i = reversed_outer_hs_lengthss[i];
|
||||
// append the reversed quad output hs lengths to the outer hs lengths
|
||||
return outer_i.push_back(quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
|
||||
// thread distr) of the major sequence
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
|
||||
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences (i.e. warp), keep them unchanged
|
||||
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto quad_idx_offset =
|
||||
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
|
||||
|
||||
// minus 1 because RsLength is not counted
|
||||
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
|
||||
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_ps_to_rhss_minor[i];
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
|
||||
constexpr auto outer_ps =
|
||||
typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
|
||||
return outer_ps.push_back(quad_output_ps_minor_offset +
|
||||
quad_output_ps_to_rhss_minor0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_i;
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto dst_ys_to_rhs_major =
|
||||
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
|
||||
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
|
||||
|
||||
using TransposedDstrEncode =
|
||||
tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using OutputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using InputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
index_t kLeadNumWarps,
|
||||
index_t kSecondNumWarps>
|
||||
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
{
|
||||
constexpr auto block_outer_dst_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
|
||||
|
||||
return blk_distr_encode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* @param offset The offset (in elements) added to the base address before
|
||||
* indexing.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE void load_tile_transpose_with_offset(
|
||||
DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
|
||||
|
||||
// Check that the tile distribution of out_tensor is the expected one for transposed loads.
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
static_assert(std::is_same_v<decltype(make_static_tile_distribution(OutTileDstrEncode{})),
|
||||
remove_cvref_t<decltype(output_distr)>>);
|
||||
|
||||
// Check that the datatype of out_tensor matches that of the bottom tensor view.
|
||||
static_assert(std::is_same_v<typename DistributedTensor_::DataType,
|
||||
typename BottomTensorView_::DataType>);
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"the element space size is not the same!");
|
||||
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
out_tensor.get_thread_buffer().template set_as<DataVec>(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* indexing.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE void
|
||||
load_tile_transpose(DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
}
|
||||
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct null_tensor
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
100
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
100
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_view.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// placeholder type if we want to opt-out a tile window parameter
|
||||
template <typename WindowLengths_>
|
||||
struct null_tile_window
|
||||
{
|
||||
using BottomTensorView = null_tensor_view;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
|
||||
using BottomTensorIndex = array<index_t, WindowLengths::size()>;
|
||||
|
||||
CK_TILE_DEVICE constexpr null_tile_window() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
|
||||
: window_lengths_{window_lengths}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
|
||||
|
||||
CK_TILE_DEVICE void init_raw() {}
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
};
|
||||
|
||||
// utility to check if this is a Null Tile Window
|
||||
namespace impl {
|
||||
template <typename>
|
||||
struct is_null_tile_window : public std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
|
||||
{
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
|
||||
{
|
||||
return is_null_tile_window_v<remove_cvref_t<T>>;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
|
||||
{
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename... Ts>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<WindowLengths::size()>& /*origin*/,
|
||||
Ts&&...)
|
||||
{
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
|
||||
const StaticTileDistribution&)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(null_tile_window<WindowLengths>&,
|
||||
const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
180
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
180
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using DataType = typename InTensor::DataType;
|
||||
|
||||
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
|
||||
// y_dim_out_to_in
|
||||
constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
|
||||
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
|
||||
|
||||
map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
|
||||
|
||||
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
|
||||
|
||||
rh_major_minor_to_y_({rh_major, rh_minor}) = i;
|
||||
});
|
||||
|
||||
return rh_major_minor_to_y_;
|
||||
};
|
||||
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
|
||||
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
|
||||
{
|
||||
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
|
||||
}
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
constexpr index_t y_dim_vec_in = NDimY - 1;
|
||||
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
|
||||
|
||||
// vector lengths
|
||||
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
|
||||
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
|
||||
|
||||
// # of vectors
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// using InVec = typename InVec::type;
|
||||
// using OutVec = typename OutVec::type;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
|
||||
using SFC_Y = space_filling_curve<decltype(y_lengths),
|
||||
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
|
||||
decltype(scalars_per_access)>;
|
||||
|
||||
constexpr index_t num_access = SFC_Y::get_num_of_access();
|
||||
|
||||
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
using InDataType = typename InTensor::DataType;
|
||||
using OutDataType = typename OutTensor::DataType;
|
||||
|
||||
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
|
||||
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
|
||||
|
||||
// type convert
|
||||
const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
|
||||
|
||||
// shuffle
|
||||
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
|
||||
InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
|
||||
InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
|
||||
InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
|
||||
InDstrEncode::NDimY == OutDstrEncode::NDimY)
|
||||
{
|
||||
detail::shuffle_tile_impl_in_thread(out, in_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "The shuffle should always happen!");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
94
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
94
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
|
||||
// NOTE: This API will override the origin of the tile window!
|
||||
static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
|
||||
static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
|
||||
|
||||
constexpr auto slice_lengths = slice_ends - slice_begins;
|
||||
|
||||
return make_tile_window(tile.get_bottom_tensor_view(),
|
||||
sequence_to_tuple_of_number(slice_lengths),
|
||||
to_multi_index(slice_begins));
|
||||
}
|
||||
|
||||
template <typename DataType_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using Distribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
|
||||
|
||||
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
|
||||
|
||||
sliced_tensor.get_thread_buffer() =
|
||||
tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
|
||||
|
||||
return sliced_tensor;
|
||||
}
|
||||
|
||||
template <typename DstDataType_,
|
||||
typename DstStaticTileDistribution_,
|
||||
typename SrcDataType_,
|
||||
typename SrcStaticTileDistribution_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
|
||||
const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
|
||||
using SrcDistribution = remove_cvref_t<SrcStaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
|
||||
|
||||
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
|
||||
|
||||
dst_tile.set_y_sliced_thread_data(
|
||||
sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
273
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
273
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
@@ -0,0 +1,273 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
static_assert(StaticTileDistribution::is_static(),
|
||||
"wrong! StaticTileDistribution should be known at compile tile");
|
||||
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
|
||||
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
|
||||
{
|
||||
return StaticTileDistribution::get_num_of_dimension_x();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
return StaticTileDistribution::get_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
|
||||
{
|
||||
return StaticTileDistribution{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
return StaticTileDistribution::get_distributed_spans();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths>
|
||||
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>) const
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
// divide element number by PackedSize to get the correct thread buffer size
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
|
||||
sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
sliced_thread_data(
|
||||
number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
|
||||
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
|
||||
});
|
||||
|
||||
return sliced_thread_data;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
|
||||
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>,
|
||||
const SlicedThreadData& sliced_thread_data)
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
|
||||
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
|
||||
PackedSize>{}];
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TileDistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
|
||||
{
|
||||
static_assert(is_static_v<TileDistributedIndices>,
|
||||
"wrong! Tile Distributed Indices should be static");
|
||||
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
|
||||
}
|
||||
|
||||
template <typename TileDistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
|
||||
{
|
||||
static_assert(is_static_v<TileDistributedIndices>,
|
||||
"wrong! Tile Distributed Indices should be static");
|
||||
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
|
||||
}
|
||||
|
||||
//
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{};
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
|
||||
ThreadBuffer&& thread_buffer_)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index)
|
||||
{
|
||||
constexpr auto y_indices =
|
||||
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
|
||||
|
||||
const auto x_coord = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
|
||||
|
||||
return x_coord.get_bottom_index();
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices)
|
||||
{
|
||||
return get_x_indices_from_distributed_indices(
|
||||
tile_distribution, distributed_indices, get_partition_index(tile_distribution));
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
DataType value,
|
||||
XIndicesPredicate predicate)
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
|
||||
distributed_indices);
|
||||
|
||||
if(predicate(x_indices))
|
||||
{
|
||||
out_tensor(distributed_indices) = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
DataType value,
|
||||
XIndicesPredicate predicate,
|
||||
decltype(get_partition_index(std::declval<StaticTileDistribution>())) partition_index)
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution{}, distributed_indices, partition_index);
|
||||
|
||||
if(predicate(x_indices))
|
||||
{
|
||||
out_tensor(distributed_indices) = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// this function used inside span loop over
|
||||
template <typename YLengths, index_t XUnpacks>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
|
||||
{
|
||||
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies<>{}, number<1>{});
|
||||
constexpr auto y_packs = number<XUnpacks>{};
|
||||
static_assert(y_size % y_packs == 0);
|
||||
constexpr auto y_slice_size = y_size / y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// check if 2 static_distributed_tensor has same data type and size of element
|
||||
// but only difference in distribution
|
||||
template <typename X, typename Y>
|
||||
struct is_similiar_distributed_tensor
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
|
||||
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
|
||||
static_distributed_tensor<TypeY, DistY>>
|
||||
{
|
||||
using Tx = static_distributed_tensor<TypeX, DistX>;
|
||||
using Ty = static_distributed_tensor<TypeY, DistY>;
|
||||
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
|
||||
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_similiar_distributed_tensor_v =
|
||||
is_similiar_distributed_tensor<X, Y>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
171
include/ck_tile/core/tensor/store_tile.hpp
Normal file
171
include/ck_tile/core/tensor/store_tile.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr,
|
||||
partition_index);
|
||||
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr,
|
||||
partition_index);
|
||||
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
308
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
308
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
@@ -0,0 +1,308 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// sweep over a span of a distribted tile and apply lambda function F
|
||||
template <typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F // signature: F(tile_distributed_index<...>)
|
||||
>
|
||||
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
|
||||
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
|
||||
|
||||
f(dstr_idx);
|
||||
});
|
||||
}
|
||||
|
||||
// unpacked span, this version support span with unpack(multi-arg) functor
|
||||
//
|
||||
template <
|
||||
typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F, // signature: F(tile_distributed_index<...>)
|
||||
typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
|
||||
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_uford<typename DstrSpan::Impl, Unpacks>{}(
|
||||
[&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl;
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) {
|
||||
return make_tuple(concat_tuple(
|
||||
si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
|
||||
},
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl_0;
|
||||
|
||||
// TODO: support empty tuple to remove this "entry-point" like function
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx =
|
||||
make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
/*
|
||||
* Enhanced sweep-tile utility, can control unpacks along each X-dim
|
||||
* the lambda function argument is the distributed-idx, which can directly
|
||||
* plugged into the distributed tensor as setter/getter
|
||||
*
|
||||
* e.g. below function, y with the type DistributedTensor, r is row scale
|
||||
*
|
||||
* // sweep tile 1 by 1
|
||||
* sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
* y(idx) = y(idx) * r(row_id);
|
||||
* });
|
||||
*
|
||||
* // sweep tile with 2 pixel from last dim each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_0, auto idx_1) {
|
||||
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
|
||||
* y(idx_0) = y(idx_0) * r(row_id);
|
||||
* y(idx_1) = y(idx_1) * r(row_id);
|
||||
* },
|
||||
* sequence<1, 2>{});
|
||||
*
|
||||
* // sweep tile with 2x2 pixel each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
|
||||
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
|
||||
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
|
||||
* y(idx_00) = y(idx_00) * r(row_id0);
|
||||
* y(idx_01) = y(idx_01) * r(row_id0);
|
||||
* y(idx_10) = y(idx_10) * r(row_id1);
|
||||
* y(idx_11) = y(idx_11) * r(row_id1);
|
||||
* },
|
||||
* sequence<2, 2>{});
|
||||
*
|
||||
* TODO: do we need constexpr? lambda function could be non-constexpr
|
||||
*/
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
/*
|
||||
* construct a sweep tile instance, which support issue the lambda one by one
|
||||
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
|
||||
* the functionality is the same as sweep_tile()
|
||||
*/
|
||||
template <typename DistributedTensor_,
|
||||
typename F_,
|
||||
typename UnpacksPerXDim_ =
|
||||
typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
|
||||
struct tile_sweeper
|
||||
{
|
||||
using DistributedTensor = remove_cvref_t<DistributedTensor_>;
|
||||
using F = remove_cvref_t<F_>;
|
||||
using UnpacksPerXDim = remove_cvref_t<UnpacksPerXDim_>;
|
||||
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
|
||||
: f(f_)
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto tmp =
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
|
||||
return tmp.get_num_of_access();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void operator()() const
|
||||
{
|
||||
sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
template <index_t i_access>
|
||||
CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
|
||||
f, number<i_access>{});
|
||||
}
|
||||
F f;
|
||||
};
|
||||
|
||||
// partial deduction is not allowed
|
||||
// template <typename T, typename F, typename U>
|
||||
// tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
// deduction guide
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
956
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
956
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
@@ -0,0 +1,956 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// BottomDimensionHiddenIds : Sequence<...>
|
||||
// TopDimensionHiddenIds : Sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
struct tensor_adaptor
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
|
||||
{
|
||||
return Transforms::size();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
|
||||
{
|
||||
return LowerDimensionHiddenIdss{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
|
||||
{
|
||||
return UpperDimensionHiddenIdss{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
|
||||
{
|
||||
return BottomDimensionHiddenIds{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
|
||||
{
|
||||
return TopDimensionHiddenIds{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
|
||||
{
|
||||
const auto lengths = generate_tuple(
|
||||
[&](auto idim_top) {
|
||||
constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);
|
||||
|
||||
constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});
|
||||
|
||||
constexpr index_t itran = tmp[number<0>{}];
|
||||
constexpr index_t idim_up = tmp[number<1>{}];
|
||||
constexpr bool found = tmp[number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
const auto length =
|
||||
transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
|
||||
|
||||
return length;
|
||||
},
|
||||
number<ndim_top_>{});
|
||||
|
||||
// TODO: make container_reduce support tuple of number and index_t
|
||||
return container_reduce(lengths, multiplies<>{}, number<1>{});
|
||||
}
|
||||
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
{
|
||||
// FIXME: length of bottom dimension is not known, since info about lower dim length are not
|
||||
// saved in transformation
|
||||
static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == IDimHidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
|
||||
{
|
||||
return BottomDimensionHiddenIds::size();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
|
||||
{
|
||||
return TopDimensionHiddenIds::size();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
|
||||
less<index_t>,
|
||||
equal<index_t>>::type;
|
||||
|
||||
return unique_sort_all_dim_ids::size();
|
||||
}
|
||||
|
||||
constexpr static index_t ntransform_ = get_num_of_transform();
|
||||
constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
|
||||
constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
|
||||
constexpr static index_t ndim_top_ = get_num_of_top_dimension();
|
||||
|
||||
using HiddenIndex = multi_index<ndim_hidden_>;
|
||||
using BottomIndex = multi_index<ndim_bottom_>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
// may be index_t or number<>
|
||||
using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
|
||||
: transforms_{transforms}, element_size_{initialize_element_size(transforms)}
|
||||
{
|
||||
static_assert(Transforms::size() == ntransform_ &&
|
||||
LowerDimensionHiddenIdss::size() == ntransform_ &&
|
||||
UpperDimensionHiddenIdss::size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }
|
||||
|
||||
// FIXME: this logic is wrong when getting bottome dimension lengths
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
|
||||
{
|
||||
static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
|
||||
|
||||
constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});
|
||||
|
||||
constexpr index_t itran = tmp[number<0>{}];
|
||||
constexpr index_t idim_up = tmp[number<1>{}];
|
||||
constexpr bool found = tmp[number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
|
||||
}
|
||||
|
||||
template <index_t IDimTop>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
|
||||
{
|
||||
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
|
||||
}
|
||||
|
||||
#if 0
|
||||
// FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
|
||||
template <index_t IDimBottom>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t
|
||||
get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
|
||||
{
|
||||
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
|
||||
{
|
||||
return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
|
||||
number<ndim_top_>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
// FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
|
||||
{
|
||||
return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
|
||||
number<ndim_bottom_>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TopIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
|
||||
{
|
||||
static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = get_num_of_transform();
|
||||
constexpr index_t ndim_hidden = get_num_of_hidden_dimension();
|
||||
|
||||
multi_index<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize uppest index
|
||||
set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
|
||||
auto itran = itran_p1 - number<1>{};
|
||||
const auto& tran = get_transforms().at(itran);
|
||||
constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_low;
|
||||
|
||||
tran.calculate_lower_index(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
bool is_known = true;
|
||||
|
||||
static_for<0, Transforms::size(), 1>{}([&](auto i) {
|
||||
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
|
||||
});
|
||||
|
||||
return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
|
||||
|
||||
template <index_t Internal = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
|
||||
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
|
||||
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
|
||||
{
|
||||
auto vector_lengths = guaranteed_vector_lengths;
|
||||
auto vector_strides = guaranteed_vector_strides;
|
||||
|
||||
static_for<0,
|
||||
Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
|
||||
1>{}([&](auto itran) {
|
||||
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto up_guaranteed_vector_lengths =
|
||||
get_container_subset(guaranteed_vector_lengths, up_dims);
|
||||
const auto up_guaranteed_vector_strides =
|
||||
get_container_subset(guaranteed_vector_strides, up_dims);
|
||||
|
||||
// only need type of transform
|
||||
auto [up_vector_lengths, up_vector_strides] =
|
||||
Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
|
||||
get_container_subset(vector_lengths, low_dims),
|
||||
get_container_subset(vector_strides, low_dims));
|
||||
|
||||
if constexpr(up_dims.size() > 0)
|
||||
{
|
||||
for(index_t i = 0; i < up_dims.size(); ++i)
|
||||
{
|
||||
up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
|
||||
? up_guaranteed_vector_lengths[i]
|
||||
: up_vector_lengths[i];
|
||||
|
||||
up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
|
||||
? up_guaranteed_vector_strides[i]
|
||||
: up_vector_strides[i];
|
||||
}
|
||||
}
|
||||
|
||||
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
|
||||
set_container_subset(vector_strides, up_dims, up_vector_strides);
|
||||
});
|
||||
if constexpr(Internal > 0)
|
||||
{
|
||||
return make_tuple(vector_lengths, vector_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto top_dims = TopDimensionHiddenIds{};
|
||||
return make_tuple(get_container_subset(vector_lengths, top_dims),
|
||||
get_container_subset(vector_strides, top_dims));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
BottomDimensionHiddenIds,
|
||||
TopDimensionHiddenIds>& adaptor)
|
||||
{
|
||||
printf("tensor_adaptor{\n");
|
||||
printf(" transforms: [");
|
||||
print(adaptor.get_transforms());
|
||||
printf("],\n");
|
||||
|
||||
printf(" LowerDimensionHiddenIds: [");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" UpperDimensionHiddenIds: [");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" BottomDimensionHiddenIds: [");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf("],\n");
|
||||
|
||||
//
|
||||
printf(" TopDimensionHiddenIds: [");
|
||||
print(TopDimensionHiddenIds{});
|
||||
printf("]\n}\n");
|
||||
}
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
|
||||
LowerDimensionOldTopIdss,
|
||||
UpperDimensionNewTopIdss)
|
||||
{
|
||||
constexpr index_t ntransform = Transforms::size();
|
||||
|
||||
static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
|
||||
UpperDimensionNewTopIdss::size() == ntransform,
|
||||
"wrong!");
|
||||
|
||||
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
|
||||
constexpr auto all_low_dim_old_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr auto all_up_dim_new_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
|
||||
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
|
||||
|
||||
// low_dim_hidden_idss
|
||||
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
|
||||
|
||||
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
|
||||
number<ntransform>{});
|
||||
|
||||
// bottom_dim_hidden_ids
|
||||
constexpr auto bottom_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
|
||||
|
||||
// top_dim_hidden_ids
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};
|
||||
|
||||
return tensor_adaptor<remove_cvref_t<Transforms>,
|
||||
remove_cvref_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
|
||||
}
|
||||
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor, and to put it outside the scope where it is used
|
||||
// (transform_tensor_adaptor) because template cannot be defined inside a function
|
||||
// template
|
||||
template <typename NewTransforms>
|
||||
struct lambda_get_up_dim_num
|
||||
{
|
||||
template <typename I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
|
||||
{
|
||||
using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
|
||||
return number<Tran::get_num_of_upper_dimension()>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OldTensorAdaptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldTopIdss,
|
||||
typename NewUpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldTopIdss,
|
||||
NewUpperDimensionNewTopIdss)
|
||||
{
|
||||
// sanity check
|
||||
{
|
||||
static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
|
||||
NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
|
||||
"wrong! inconsitent number of transform");
|
||||
|
||||
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewLowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// lower dimension's hidden idss
|
||||
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
|
||||
// sequences)
|
||||
constexpr auto low_dim_hidden_idss = transform_tuples(
|
||||
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
|
||||
[](auto low_dim_top_ids) constexpr {
|
||||
return transform_sequences(
|
||||
// convert lower dimension top id to hidden id
|
||||
[](auto low_dim_top_id) constexpr {
|
||||
return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
|
||||
},
|
||||
low_dim_top_ids);
|
||||
},
|
||||
NewLowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr index_t num_new_transform = NewTransforms::size();
|
||||
|
||||
// upper dimension's hidden idss
|
||||
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();
|
||||
|
||||
constexpr auto up_dim_numbers =
|
||||
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});
|
||||
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
1>::type{};
|
||||
},
|
||||
number<num_new_transform>{});
|
||||
|
||||
// new top dimension's hidden ids
|
||||
constexpr auto unordered_new_top_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_top_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
|
||||
|
||||
constexpr auto new_top_dim_hidden_ids =
|
||||
unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);
|
||||
|
||||
// put everything together
|
||||
const auto all_transforms =
|
||||
container_concat(old_tensor_adaptor.get_transforms(), new_transforms);
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);
|
||||
|
||||
return tensor_adaptor<
|
||||
remove_cvref_t<decltype(all_transforms)>,
|
||||
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
template <typename TensorAdaptor0, typename TensorAdaptor1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
|
||||
const TensorAdaptor1& adaptor1)
|
||||
{
|
||||
static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
|
||||
TensorAdaptor1::get_num_of_bottom_dimension(),
|
||||
"wrong!");
|
||||
|
||||
// all_transforms = transform0 + transform1
|
||||
const auto all_transforms =
|
||||
container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());
|
||||
|
||||
// shift
|
||||
constexpr index_t adaptor0_max_hidden_id = [&]() {
|
||||
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
|
||||
|
||||
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();
|
||||
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor0_max_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_min_hidden_id = [&]() {
|
||||
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
|
||||
|
||||
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();
|
||||
|
||||
// get the min of all lower dimenions, but not bottom dimension (because their id will
|
||||
// be matched with top id from adaptor0)
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
constexpr index_t low_dim_hidden_id =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;
|
||||
|
||||
bool is_bottom_dim = false;
|
||||
static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
|
||||
if constexpr(low_dim_hidden_id ==
|
||||
TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
|
||||
{
|
||||
is_bottom_dim = true;
|
||||
}
|
||||
});
|
||||
|
||||
if(!is_bottom_dim)
|
||||
{
|
||||
adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
|
||||
}
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();
|
||||
|
||||
// get the min of all upper dimensions
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor1_min_hidden_id_ =
|
||||
min(adaptor1_min_hidden_id_,
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor1_min_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_hidden_id_shift =
|
||||
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
|
||||
|
||||
constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();
|
||||
|
||||
// all_low_dim_hidden_idss =
|
||||
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
|
||||
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_low_1 =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();
|
||||
|
||||
constexpr auto low_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
// match hidden id
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
|
||||
// if this low dim is bottom dim, then do id matching
|
||||
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
|
||||
TensorAdaptor1::get_bottom_dimension_hidden_ids()
|
||||
[idim_bottom_1])
|
||||
{
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) =
|
||||
TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
number<ndim_low_1>{});
|
||||
},
|
||||
number<TensorAdaptor1::get_num_of_transform()>{});
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);
|
||||
|
||||
// all_up_dim_hidden_idss =
|
||||
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
|
||||
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_up_1 =
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();
|
||||
|
||||
constexpr auto up_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
|
||||
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
|
||||
number<ndim_up_1>{});
|
||||
},
|
||||
number<TensorAdaptor1::get_num_of_transform()>{});
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);
|
||||
|
||||
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
|
||||
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();
|
||||
|
||||
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};
|
||||
|
||||
// put everything together
|
||||
return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
|
||||
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Xs,
|
||||
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// Macro function
|
||||
// construct constexpr tensor_adaptor from constexpr encoding
|
||||
// encoded_tensor_adaptor are Tuple of following objects:
|
||||
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
|
||||
// 1.1 name (coord_transform_enum)
|
||||
// 1.2 meta data for constructor of the transform
|
||||
// 1.3 num of lower dimension (index_t)
|
||||
// 1.4 lower dimension Ids (array of fixed size)
|
||||
// 1.5 num of up dimension (index_t)
|
||||
// 1.6 upper dimension Ids (array of fixed size)
|
||||
// 2. num of transforms (index_t)
|
||||
// 3. encoded bottom dimension Ids (array of fixed size)
|
||||
// 4. num of bottom dimension (index_t)
|
||||
// 5. encoded top dimension Ids (array of fixed size)
|
||||
// 6. num of top dimension (index_t)
|
||||
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
|
||||
[encoded_tensor_adaptor]() { \
|
||||
using namespace ck_tile; \
|
||||
\
|
||||
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
|
||||
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
|
||||
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
|
||||
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
\
|
||||
static_assert(name == coord_transform_enum::pass_through || \
|
||||
name == coord_transform_enum::pad || \
|
||||
name == coord_transform_enum::embed || \
|
||||
name == coord_transform_enum::merge || \
|
||||
name == coord_transform_enum::unmerge || \
|
||||
name == coord_transform_enum::replicate, \
|
||||
""); \
|
||||
\
|
||||
if constexpr(name == coord_transform_enum::pass_through) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto low_len = meta_data.template pop<index_t>(pos); \
|
||||
\
|
||||
return make_pass_through_transform(low_len); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::pad) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto low_len = meta_data.template pop<index_t>(pos); \
|
||||
auto left_pad = meta_data.template pop<index_t>(pos); \
|
||||
auto right_pad = meta_data.template pop<index_t>(pos); \
|
||||
\
|
||||
return make_pad_transform(low_len, left_pad, right_pad); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::embed) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
|
||||
auto coefficients = \
|
||||
meta_data.template pop<array<index_t, num_up_dim>>(pos); \
|
||||
\
|
||||
return make_embed_transform(up_lens, coefficients); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::merge) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
|
||||
\
|
||||
return make_merge_transform(low_lens); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::unmerge) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
|
||||
\
|
||||
return make_unmerge_transform(up_lens); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::replicate) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
|
||||
\
|
||||
return make_replicate_transform(up_lens); \
|
||||
} \
|
||||
}, \
|
||||
number<num_transform>{}); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
|
||||
\
|
||||
return TO_SEQUENCE(low_dims, num_low_dim); \
|
||||
}, \
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
|
||||
\
|
||||
return TO_SEQUENCE(up_dims, num_up_dim); \
|
||||
}, \
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
|
||||
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
|
||||
\
|
||||
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
|
||||
remove_cvref_t<decltype(low_dim_idss)>, \
|
||||
remove_cvref_t<decltype(up_dim_idss)>, \
|
||||
remove_cvref_t<decltype(bottom_dim_ids)>, \
|
||||
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
|
||||
}()
|
||||
|
||||
// Macro function
|
||||
// construct static tensor_adaptor from constexpr encoding
|
||||
// encoded_tensor_adaptor are Tuple of following objects:
|
||||
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
|
||||
// 1.1 name (coord_transform_enum)
|
||||
// 1.2 meta data for constructor of the transform
|
||||
// 1.3 num of lower dimension (index_t)
|
||||
// 1.4 lower dimension Ids (array of fixed size)
|
||||
// 1.5 num of up dimension (index_t)
|
||||
// 1.6 upper dimension Ids (array of fixed size)
|
||||
// 2. num of transforms (index_t)
|
||||
// 3. encoded bottom dimension Ids (array of fixed size)
|
||||
// 4. num of bottom dimension (index_t)
|
||||
// 5. encoded top dimension Ids (array of fixed size)
|
||||
// 6. num of top dimension (index_t)
|
||||
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
|
||||
[encoded_tensor_adaptor]() { \
|
||||
using namespace ck_tile; \
|
||||
\
|
||||
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
|
||||
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
|
||||
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
|
||||
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
\
|
||||
static_assert(name == coord_transform_enum::pass_through || \
|
||||
name == coord_transform_enum::pad || \
|
||||
name == coord_transform_enum::embed || \
|
||||
name == coord_transform_enum::merge || \
|
||||
name == coord_transform_enum::unmerge || \
|
||||
name == coord_transform_enum::replicate, \
|
||||
""); \
|
||||
\
|
||||
if constexpr(name == coord_transform_enum::pass_through) \
|
||||
{ \
|
||||
constexpr index_t low_len = meta_data.template get<index_t>(0); \
|
||||
\
|
||||
return make_pass_through_transform(number<low_len>{}); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::pad) \
|
||||
{ \
|
||||
constexpr index_t low_len = meta_data.template get<index_t>(0); \
|
||||
\
|
||||
constexpr index_t left_pad = \
|
||||
meta_data.template get<index_t>(sizeof(low_len)); \
|
||||
\
|
||||
constexpr index_t right_pad = \
|
||||
meta_data.template pop<index_t>(sizeof(low_len) + sizeof(left_pad)); \
|
||||
\
|
||||
return make_pad_transform( \
|
||||
number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::embed) \
|
||||
{ \
|
||||
constexpr auto up_lens = \
|
||||
meta_data.template get<array<index_t, num_up_dim>>(0); \
|
||||
\
|
||||
constexpr auto coefficients = \
|
||||
meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
|
||||
\
|
||||
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
|
||||
TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::merge) \
|
||||
{ \
|
||||
constexpr auto low_lens = \
|
||||
meta_data.template get<array<index_t, num_low_dim>>(0); \
|
||||
\
|
||||
return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::unmerge) \
|
||||
{ \
|
||||
constexpr auto up_lens = \
|
||||
meta_data.template get<array<index_t, num_up_dim>>(0); \
|
||||
\
|
||||
return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
|
||||
} \
|
||||
else if constexpr(name == coord_transform_enum::replicate) \
|
||||
{ \
|
||||
constexpr auto up_lens = \
|
||||
meta_data.template get<array<index_t, num_up_dim>>(0); \
|
||||
\
|
||||
return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
|
||||
} \
|
||||
}, \
|
||||
number<num_transform>{}); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
|
||||
\
|
||||
return TO_SEQUENCE(low_dims, num_low_dim); \
|
||||
}, \
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
|
||||
\
|
||||
return TO_SEQUENCE(up_dims, num_up_dim); \
|
||||
}, \
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
|
||||
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
|
||||
\
|
||||
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
|
||||
remove_cvref_t<decltype(low_dim_idss)>, \
|
||||
remove_cvref_t<decltype(up_dim_idss)>, \
|
||||
remove_cvref_t<decltype(bottom_dim_ids)>, \
|
||||
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
|
||||
}()
|
||||
#pragma clang diagnostic pop
|
||||
373
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
373
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
|
||||
struct tensor_adaptor_coordinate
|
||||
{
|
||||
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
|
||||
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
|
||||
|
||||
using HiddenIndex = multi_index<NDimHidden>;
|
||||
using BottomIndex = multi_index<ndim_bottom_>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
|
||||
: idx_hidden_{idx_hidden}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
|
||||
{
|
||||
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
|
||||
{
|
||||
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }
|
||||
|
||||
//
|
||||
HiddenIndex idx_hidden_;
|
||||
};
|
||||
|
||||
template <typename Adaptor, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
const TopIndex& idx_top)
|
||||
{
|
||||
static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
|
||||
constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
|
||||
constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
|
||||
|
||||
multi_index<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize visible index
|
||||
set_container_subset(idx_hidden, top_dim_ids, idx_top);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
|
||||
auto itran = itran_p1 - number<1>{};
|
||||
const auto& tran = adaptor.get_transforms().at(itran);
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_low;
|
||||
|
||||
tran.calculate_lower_index(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return tensor_adaptor_coordinate<ndim_hidden,
|
||||
remove_cvref_t<decltype(bottom_dim_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true,
|
||||
typename Adaptor,
|
||||
typename AdaptorCoord,
|
||||
typename TopIndex,
|
||||
typename BottomIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
AdaptorCoord& coord,
|
||||
const TopIndex& idx_diff_top,
|
||||
BottomIndex& idx_diff_bottom)
|
||||
{
|
||||
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
|
||||
constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension();
|
||||
// constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
|
||||
// static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");
|
||||
|
||||
// judge whether calculation of lower diff is needed for each transform
|
||||
// use index_t for boolean type
|
||||
auto do_transforms = make_zero_multi_index<ntransform>();
|
||||
|
||||
if constexpr(JudgeDoTransforms)
|
||||
{
|
||||
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// decide do_transform by checkout non-zero index diff components
|
||||
multi_index<ndim_top> non_zero_diff_pick_top;
|
||||
|
||||
static_for<0, ndim_top, 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
|
||||
|
||||
set_container_subset(
|
||||
is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> non_zero_diff_pick_low;
|
||||
|
||||
// if any of upper index diff components is non-zero, then
|
||||
// 1) Need to do this transform
|
||||
// 2) all components of lower index diff will assume to be non-zero and need to be
|
||||
// computed
|
||||
const bool idx_diff_up_has_non_zero = container_reduce(
|
||||
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
|
||||
|
||||
do_transforms(itran) = idx_diff_up_has_non_zero;
|
||||
|
||||
static_for<0, dims_low.size(), 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
|
||||
|
||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
|
||||
}
|
||||
|
||||
// this is what needs to be calculated
|
||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// initialize top index diff
|
||||
set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);
|
||||
|
||||
// this is what needs to be updated
|
||||
auto& idx_hidden = coord.get_hidden_index();
|
||||
|
||||
// update top index
|
||||
auto idx_hidden_pick_top =
|
||||
get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());
|
||||
|
||||
idx_hidden_pick_top += idx_diff_top;
|
||||
|
||||
set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);
|
||||
|
||||
// update rest of hidden index
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
if(do_transforms[itran])
|
||||
{
|
||||
const auto& tran = adaptor.get_transforms().at(itran);
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
|
||||
auto idx_low = get_container_subset(idx_hidden, dims_low);
|
||||
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_diff_low;
|
||||
|
||||
tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
|
||||
|
||||
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
}
|
||||
});
|
||||
|
||||
// set bottom index diff
|
||||
idx_diff_bottom =
|
||||
get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
AdaptorCoord& coord,
|
||||
const TopIndex& idx_diff_top)
|
||||
{
|
||||
constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
|
||||
|
||||
multi_index<ndim_bottom> tmp;
|
||||
|
||||
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename AdaptorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool
|
||||
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
|
||||
const AdaptorCoord& coord)
|
||||
{
|
||||
bool valid = true;
|
||||
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
|
||||
const auto& idx_hidden = coord.get_hidden_index();
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
|
||||
const auto tran = adaptor.get_transforms().at(itran);
|
||||
|
||||
// check validity, only if current transformation does not always has a valid mapping
|
||||
if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
|
||||
{
|
||||
const auto idx_up = get_container_subset(
|
||||
idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));
|
||||
|
||||
// Comment: using valid = valid && .. will result in weird control flow in ISA
|
||||
valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
|
||||
}
|
||||
});
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename AdpatorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
|
||||
const AdpatorCoord& coord)
|
||||
{
|
||||
// check top index
|
||||
const auto& idx_top = coord.get_top_index();
|
||||
|
||||
bool is_top_index_valid = true;
|
||||
|
||||
static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
|
||||
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
|
||||
is_top_index_valid =
|
||||
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
|
||||
});
|
||||
|
||||
// check other hidden index
|
||||
return is_top_index_valid &&
|
||||
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <typename PREFIX = str_literal<>, typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINT_X_;
|
||||
|
||||
template <char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINT_X_<str_literal<PREFIXChars...>, str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
struct detail;
|
||||
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
|
||||
struct detail<
|
||||
tensor_adaptor_coordinate<NDimHidden, BottomDimensionHiddenIds, TopDimensionHiddenIds>>
|
||||
{
|
||||
using coord_t =
|
||||
tensor_adaptor_coordinate<NDimHidden, BottomDimensionHiddenIds, TopDimensionHiddenIds>;
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format_i()
|
||||
{
|
||||
constexpr bool is_bottom =
|
||||
sequence_any_of(BottomDimensionHiddenIds{}, [](auto b) { return b == I; });
|
||||
constexpr bool is_top =
|
||||
sequence_any_of(TopDimensionHiddenIds{}, [](auto t) { return t == I; });
|
||||
constexpr auto d = make_str_literal("%d");
|
||||
if constexpr(is_bottom && is_top)
|
||||
return make_str_literal("_^") + d;
|
||||
else if constexpr(is_bottom)
|
||||
return make_str_literal("_") + d;
|
||||
else if constexpr(is_top)
|
||||
return make_str_literal("^") + d;
|
||||
else
|
||||
return d;
|
||||
}
|
||||
template <index_t N = NDimHidden>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format()
|
||||
{
|
||||
constexpr auto sep = make_str_literal(" ");
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else
|
||||
return get_hidden_format<N - 1>() + sep + get_hidden_format_i<N - 1>();
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_format()
|
||||
{
|
||||
constexpr auto d = make_str_literal("%d");
|
||||
constexpr auto sep = make_str_literal(" ");
|
||||
constexpr auto bottom_fmt =
|
||||
d.template duplicate_n<BottomDimensionHiddenIds::size()>(sep);
|
||||
constexpr auto top_fmt = d.template duplicate_n<TopDimensionHiddenIds::size()>(sep);
|
||||
constexpr auto hidden_fmt = get_hidden_format();
|
||||
return make_str_literal("[ __") + bottom_fmt + make_str_literal("__ | ^^") + top_fmt +
|
||||
make_str_literal("^^ | ") + hidden_fmt + make_str_literal(" ]");
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_values()
|
||||
{
|
||||
return BottomDimensionHiddenIds::size() + TopDimensionHiddenIds::size() + NDimHidden;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_values(const coord_t& coord)
|
||||
{
|
||||
return container_concat(
|
||||
coord.get_bottom_index(), coord.get_top_index(), coord.get_hidden_index());
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <char... FMTChars, typename TArgs, index_t... Is, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void impl(str_literal<FMTChars...>,
|
||||
const TArgs& targs,
|
||||
std::integer_sequence<index_t, Is...>,
|
||||
Args&&... args) const
|
||||
{
|
||||
constexpr auto fmt_wrap_v = get_prefix() + str_literal<FMTChars...>{} + get_suffix();
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data, get_thread_id(), args..., targs.at(number<Is>())...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
template <typename T, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(T&& x, Args&&... args) const
|
||||
{
|
||||
using detail_t = detail<remove_cvref_t<T>>;
|
||||
impl(detail_t::get_format(),
|
||||
detail_t::get_values(std::forward<T>(x)),
|
||||
std::make_integer_sequence<index_t, (detail_t::get_num_values())>{},
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N, typename B, typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate<N, B, T>& coord)
|
||||
{
|
||||
detail::CK_PRINT_X_<>{}(coord);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
97
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
97
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimHidden, typename TopDimensionHiddenIds>
|
||||
struct tensor_coordinate
|
||||
: public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
|
||||
{
|
||||
using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;
|
||||
|
||||
// TODO make these private
|
||||
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
|
||||
|
||||
using HiddenIndex = multi_index<NDimHidden>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
|
||||
: Base{idx_hidden}
|
||||
{
|
||||
}
|
||||
|
||||
// construct from TensorAdaptorCoordinte base class
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
|
||||
{
|
||||
return Base::get_bottom_index()[number<0>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
|
||||
{
|
||||
return Base::get_hidden_index();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
|
||||
};
|
||||
|
||||
template <typename TensorDesc, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
const TopIndex& idx_top)
|
||||
{
|
||||
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
|
||||
|
||||
return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
|
||||
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
|
||||
adaptor_coord};
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
|
||||
{
|
||||
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
return adaptor_coordinate_is_valid(tensor_desc, coord);
|
||||
}
|
||||
|
||||
template <index_t N, typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const tensor_coordinate<N, T>& coord)
|
||||
{
|
||||
print(static_cast<typename tensor_coordinate<N, T>::Base>(coord));
|
||||
}
|
||||
} // namespace ck_tile
|
||||
507
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
507
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
@@ -0,0 +1,507 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
|
||||
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
|
||||
// TopDimensionHiddenIds> : sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths_,
|
||||
typename GuaranteedVectorSrides_>
|
||||
struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
sequence<0>,
|
||||
TopDimensionHiddenIds>
|
||||
{
|
||||
using Base = tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
sequence<0>,
|
||||
TopDimensionHiddenIds>;
|
||||
|
||||
using ElementSpaceSizeType = ElementSpaceSize;
|
||||
|
||||
constexpr static index_t ntransform_ = Base::get_num_of_transform();
|
||||
constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
|
||||
constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension();
|
||||
|
||||
using GuaranteedVectorLengths = GuaranteedVectorLengths_;
|
||||
using GuaranteedVectorStrides = GuaranteedVectorSrides_;
|
||||
|
||||
static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
|
||||
GuaranteedVectorStrides::size() == ndim_hidden_,
|
||||
"wrong! inconsistent # of hidden dimensions");
|
||||
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
using HiddenIndex = multi_index<ndim_hidden_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
|
||||
ElementSpaceSize element_space_size)
|
||||
: Base{transforms}, element_space_size_{element_space_size}
|
||||
|
||||
{
|
||||
static_assert(Transforms::size() == ntransform_ &&
|
||||
LowerDimensionHiddenIdss::size() == ntransform_ &&
|
||||
UpperDimensionHiddenIdss::size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
// construct from tensor_adaptor base class
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
|
||||
ElementSpaceSize element_space_size)
|
||||
: Base{adaptor}, element_space_size_{element_space_size}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
|
||||
{
|
||||
return Base::get_num_of_top_dimension();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
|
||||
{
|
||||
return Base::get_top_dimension_length(idim);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
|
||||
{
|
||||
return Base::get_top_dimension_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
|
||||
{
|
||||
return element_space_size_;
|
||||
}
|
||||
|
||||
template <typename Idx>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
|
||||
{
|
||||
return Base::calculate_bottom_index(idx)[number<0>{}];
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
|
||||
{
|
||||
return Base::get_transforms();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
|
||||
{
|
||||
return Base::get_lower_dimension_hidden_idss();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
|
||||
{
|
||||
return Base::get_upper_dimension_hidden_idss();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
|
||||
{
|
||||
return Base::get_top_dimension_hidden_ids();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
return Base::is_known_at_compile_time() &&
|
||||
ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
|
||||
|
||||
template <index_t Internal = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
|
||||
{
|
||||
return Base::template get_top_dimension_safe_vector_length_strides<Internal>(
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths,
|
||||
typename GuaranteedVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>& descriptor)
|
||||
{
|
||||
printf("tensor_descriptor{\n");
|
||||
// first print the tensor adaptor part of the descriptor using the base class print
|
||||
using Base = typename tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>::Base;
|
||||
print(static_cast<const Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
|
||||
printf("guaranteed_vector_lengths: ");
|
||||
print(GuaranteedVectorLengths{});
|
||||
printf(",\nguaranteed_vector_strides: ");
|
||||
print(GuaranteedVectorStrides{});
|
||||
printf("}\n}\n");
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename ElementSpaceSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
|
||||
const ElementSpaceSize& element_space_size)
|
||||
{
|
||||
constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();
|
||||
|
||||
return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
|
||||
remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(element_space_size)>,
|
||||
typename uniform_sequence_gen<NDimHidden, -1>::type,
|
||||
typename uniform_sequence_gen<NDimHidden, -1>::type>{
|
||||
adaptor, element_space_size};
|
||||
}
|
||||
|
||||
template <typename OldTensorDescriptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldTopIdss,
|
||||
typename NewUpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldTopIdss,
|
||||
NewUpperDimensionNewTopIdss)
|
||||
{
|
||||
const auto element_space_size = old_tensor_desc.get_element_space_size();
|
||||
|
||||
const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
|
||||
new_transforms,
|
||||
NewLowerDimensionOldTopIdss{},
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
|
||||
constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();
|
||||
|
||||
using NewGuaranteedVectorLengths = typename sequence_merge<
|
||||
typename OldTensorDescriptor::GuaranteedVectorLengths,
|
||||
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
|
||||
|
||||
using NewGuaranteedVectorStrides = typename sequence_merge<
|
||||
typename OldTensorDescriptor::GuaranteedVectorStrides,
|
||||
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
|
||||
|
||||
return tensor_descriptor<
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(element_space_size)>,
|
||||
NewGuaranteedVectorLengths,
|
||||
NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Lengths, typename Strides, index_t I, typename AccOld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
number<I> i,
|
||||
AccOld acc_old)
|
||||
{
|
||||
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
|
||||
|
||||
if constexpr(i.value < Lengths::size() - 1)
|
||||
{
|
||||
return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return acc_new;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/*
|
||||
* These functions create naive tensor descriptor
|
||||
*/
|
||||
|
||||
// Lengths..., Strides... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) long_number<>
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size =
|
||||
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorStride>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}
|
||||
|
||||
// tensor descriptor with offset, the offset will not be added into element space size
|
||||
// only have an information of the starting offset, and will impact on offset calculation
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename offset,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
const offset& os,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = detail::calculate_element_space_size_impl(
|
||||
lengths, strides, number<0>{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = sequence<1>{};
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorStride>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}();
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_embed_transform(lengths, strides)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) long_number<>
|
||||
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_unmerge_transform(lengths));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
|
||||
|
||||
constexpr index_t first_dim_length = []() {
|
||||
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
|
||||
return decltype(element_space_size)::value;
|
||||
else
|
||||
return -1;
|
||||
}();
|
||||
using last_t = remove_cvref_t<decltype(lengths.template get<N - 1>())>;
|
||||
constexpr index_t last_dim_length = []() {
|
||||
if constexpr(is_constant_v<last_t>)
|
||||
return std::max(last_t::value, GuaranteedLastDimensionVectorLength);
|
||||
else
|
||||
return -1;
|
||||
}();
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<sequence<first_dim_length>,
|
||||
typename uniform_sequence_gen<N - 1, -1>::type,
|
||||
sequence<last_dim_length>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<sequence<1>,
|
||||
typename uniform_sequence_gen<N - 1, -1>::type,
|
||||
sequence<1>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}
|
||||
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename Offset,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
|
||||
const tuple<Lengths...>& lengths,
|
||||
const Offset& offset,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = sequence<1>{};
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}();
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_unmerge_transform(lengths)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// align could be:
|
||||
// 1) index_t, or
|
||||
// 2) number<>
|
||||
template <typename... Lengths, typename Align>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);
|
||||
|
||||
auto strides = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == N - 1)
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
else if constexpr(i.value == N - 2)
|
||||
{
|
||||
return number<stride_n_minus_2>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return container_reduce(lengths,
|
||||
multiplies<>{},
|
||||
number<stride_n_minus_2>{},
|
||||
i + I1,
|
||||
number<N - 1>{},
|
||||
I1);
|
||||
}
|
||||
},
|
||||
number<N>{});
|
||||
|
||||
return make_naive_tensor_descriptor(lengths, strides);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
611
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
611
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
@@ -0,0 +1,611 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* tensor_view
|
||||
* abstract the underneath memory buffer(global, LDS, etc...)
|
||||
* and provide a unified get/set function for access
|
||||
*
|
||||
* For addressing into the buffer we use 2 variable to control:
|
||||
* coord : ND tensor coordinate, will calculate the actual offset inside
|
||||
* linear_offset : 1D offset, will be used in the immediate field of
|
||||
* the buffer instruction to help reduce register usage
|
||||
*
|
||||
* User can use either of the field, or both to indexing into the tensor
|
||||
*
|
||||
* We usually provide 2 set of API for buffer get/set, e.g.
|
||||
* get_vectorized_elements()/get_vectorized_elements_raw()
|
||||
* the former usually will call intrinsic or normal C function, the later
|
||||
* usually will call inline-asm function
|
||||
*
|
||||
*/
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
struct tensor_view
|
||||
{
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
using DataType = typename buffer_view::type;
|
||||
using DataType_ = remove_cvref_t<DataType>;
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
static constexpr index_t PackedSize = ck_tile::numeric_traits<DataType_>::PackedSize;
|
||||
|
||||
template <typename T>
|
||||
using vector_scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
|
||||
const TensorDesc& desc)
|
||||
: buf_{buffer_view}, desc_{desc}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
|
||||
{
|
||||
return TensorDesc::get_num_of_top_dimension();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element, // flag
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(dst,
|
||||
coord.get_offset() /
|
||||
PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
index_t IMM = 0,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType_>>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
index_t offset,
|
||||
index_t wave_offset,
|
||||
number<IMM> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
offset / PackedSize,
|
||||
wave_offset,
|
||||
number<IMM / PackedSize>{},
|
||||
true,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType_>>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize + linear_offset / PackedSize,
|
||||
0,
|
||||
0, // linear_offset need to be imm and is not supported currently
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
0,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t coord_extra_offset,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
(coord.get_offset() + coord_extra_offset) / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// member
|
||||
buffer_view buf_;
|
||||
TensorDesc desc_;
|
||||
};
|
||||
|
||||
// placeholder type if we want to opt-out a tile view parameter
|
||||
struct null_tensor_view
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_tensor_view : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename BufferView, typename TensorDesc, memory_operation_enum DstInMemOp>
|
||||
struct is_tensor_view<tensor_view<BufferView, TensorDesc, DstInMemOp>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_tensor_view<null_tensor_view> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_tensor_view_v = is_tensor_view<T>::value;
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
|
||||
const tensor_descriptor<Ts...>& desc)
|
||||
{
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view_packed(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
|
||||
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <typename OldTensorView,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
typename NewUpperDimensionNewVisibleIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss,
|
||||
NewUpperDimensionNewVisibleIdss)
|
||||
{
|
||||
auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
|
||||
new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::buffer_view,
|
||||
remove_cvref_t<decltype(new_desc)>,
|
||||
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename TileLengths, // tuple<...>
|
||||
typename DoPads> // sequence<bool, bool, ...>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
|
||||
constexpr index_t num_dim = DoPads::size();
|
||||
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
|
||||
"wrong! inconsistent # of dimensions");
|
||||
|
||||
// transforms
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto idim) {
|
||||
const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);
|
||||
|
||||
const auto tile_length = tile_lengths[idim];
|
||||
|
||||
const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;
|
||||
|
||||
const auto pad_length = new_length - old_length;
|
||||
|
||||
constexpr bool DoPad = DoPads::at(idim);
|
||||
|
||||
const auto transform =
|
||||
conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
|
||||
make_pass_through_transform(old_length));
|
||||
|
||||
return transform;
|
||||
},
|
||||
number<num_dim>{});
|
||||
|
||||
// lower dimension Id
|
||||
const auto lower_dimss =
|
||||
generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});
|
||||
|
||||
// upper dimension Id
|
||||
const auto upper_dimss = lower_dimss;
|
||||
|
||||
return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
737
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
737
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
@@ -0,0 +1,737 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
return Distribution::get_partition_index();
|
||||
}
|
||||
|
||||
// distributed span
|
||||
template <index_t... PartialHsLengths>
|
||||
struct tile_distributed_span
|
||||
{
|
||||
using Impl = sequence<PartialHsLengths...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
// distributed index
|
||||
template <index_t... PartialHsIndices>
|
||||
struct tile_distributed_index
|
||||
{
|
||||
using Impl = sequence<PartialHsIndices...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_span<Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_index<Is...>{};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
|
||||
// should be more elegnat
|
||||
struct tile_distribution
|
||||
{
|
||||
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
|
||||
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
|
||||
|
||||
static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
|
||||
"wrong! should be static");
|
||||
|
||||
static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
|
||||
static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
|
||||
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
|
||||
|
||||
PsYs2XsAdaptor ps_ys_to_xs_;
|
||||
Ys2DDescriptor ys_to_d_;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto get_partition_index()
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
return array<index_t, 1>{get_lane_id()};
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
#if 0
|
||||
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
|
||||
ps_ys_to_xs_.GetBottomDimensionLengths();
|
||||
#else
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t x_length =
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies<>{}, 1);
|
||||
|
||||
return number<x_length>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
|
||||
{
|
||||
return ps_ys_to_xs_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
|
||||
{
|
||||
return DstrEncode{};
|
||||
}
|
||||
|
||||
#if 1
|
||||
// Calculate Replication index [R0, R1, ...] based on Partion index
|
||||
// FIXME: very nasty implementation
|
||||
template <typename PartitionIndex>
|
||||
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
|
||||
{
|
||||
static_assert(PartitionIndex::size() == NDimP, "wrong!");
|
||||
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
|
||||
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
|
||||
array<index_t, NDimR> rs_idx;
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
|
||||
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
|
||||
|
||||
// 0-th rh_major is the replicate dimension
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
constexpr index_t adaptor_hidden_id =
|
||||
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
|
||||
|
||||
// fill in
|
||||
rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return rs_idx;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename PartitionIndex = decltype(get_partition_index())>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const
|
||||
{
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
const auto window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
return window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
|
||||
constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto span_impl = distributed_spans_impl[i];
|
||||
constexpr index_t ndim_span_minor = ndims_spans_minor[i];
|
||||
|
||||
constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
|
||||
|
||||
return detail::make_tile_distributed_span(span);
|
||||
},
|
||||
number<NDimX>{});
|
||||
}
|
||||
|
||||
// FIXME: it's hacky to get Y index from Distributed-Index
|
||||
template <typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
{
|
||||
constexpr auto ys_idx_arr = [] {
|
||||
array<index_t, NDimY> ys_idx;
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
|
||||
constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
|
||||
|
||||
constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
|
||||
|
||||
ys_idx(i) = dstr_index.impl_[span_minor];
|
||||
});
|
||||
|
||||
return ys_idx;
|
||||
}();
|
||||
|
||||
constexpr index_t ndim_y = NDimY;
|
||||
|
||||
return TO_SEQUENCE(ys_idx_arr, ndim_y);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_tile_distribution : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename PsYs2XsAdaptor,
|
||||
typename Ys2DDescriptor,
|
||||
typename StaticTileDistributionEncoding,
|
||||
typename TileDistributionDetail>
|
||||
struct is_tile_distribution<tile_distribution<PsYs2XsAdaptor,
|
||||
Ys2DDescriptor,
|
||||
StaticTileDistributionEncoding,
|
||||
TileDistributionDetail>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_tile_distribution_v = is_tile_distribution<T>::value;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t NDimMax>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
|
||||
{
|
||||
array<index_t, NDimMax> arr{0};
|
||||
|
||||
for(index_t i = 0; i < iend - ibegin; ++i)
|
||||
{
|
||||
arr(i) = ibegin + i;
|
||||
}
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
|
||||
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
|
||||
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
|
||||
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
|
||||
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
|
||||
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
|
||||
|
||||
// FIXME: increase max value if fail
|
||||
constexpr index_t kMaxNumTransforms = 20;
|
||||
constexpr index_t kMaxMetaDataSize = 128;
|
||||
constexpr index_t kMaxNumDim = 10;
|
||||
|
||||
using Name = coord_transform_enum;
|
||||
using MetaData = meta_data_buffer<kMaxMetaDataSize>;
|
||||
using NumDim = index_t;
|
||||
using Dims = array<index_t, kMaxNumDim>;
|
||||
using Lengths = array<index_t, kMaxNumDim>;
|
||||
|
||||
// Tile Adaptor
|
||||
// bottom dims [x0, x1, x2, ...]
|
||||
// top dims [p0, p1, ..., y0, y1, ...]
|
||||
constexpr index_t ndim_x = HsLengthss::size();
|
||||
|
||||
// Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
|
||||
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
|
||||
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
|
||||
|
||||
auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
|
||||
|
||||
index_t num_tran = 0;
|
||||
index_t hidden_dim_cnt = ndim_x;
|
||||
|
||||
// this is replicate transform
|
||||
{
|
||||
constexpr index_t ndim_r_minor = RsLengths::size();
|
||||
|
||||
constexpr auto r_minor_lengths = RsLengths{};
|
||||
|
||||
trans(num_tran++) = {
|
||||
coord_transform_enum::replicate,
|
||||
MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
|
||||
NumDim{0},
|
||||
Dims{},
|
||||
NumDim{ndim_r_minor},
|
||||
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
|
||||
|
||||
for(index_t i = 0; i < ndim_r_minor; ++i)
|
||||
{
|
||||
rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
|
||||
rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
|
||||
|
||||
hidden_dim_cnt++;
|
||||
}
|
||||
};
|
||||
|
||||
// these are Unmerge transforms for X dimesions
|
||||
static_for<0, ndim_x, 1>{}([&trans,
|
||||
&num_tran,
|
||||
&hidden_dim_cnt,
|
||||
&rh_major_minor_to_hidden_ids,
|
||||
&rh_major_minor_to_hidden_lengths](auto idim_x) {
|
||||
// typename HsLengthss::base{}.foo();
|
||||
constexpr auto h_minor_lengths =
|
||||
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
|
||||
|
||||
constexpr index_t ndim_h_minor = h_minor_lengths.size();
|
||||
|
||||
trans(num_tran++) = {
|
||||
coord_transform_enum::unmerge,
|
||||
MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
|
||||
NumDim{1},
|
||||
Dims{idim_x},
|
||||
NumDim{ndim_h_minor},
|
||||
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
|
||||
|
||||
for(index_t i = 0; i < ndim_h_minor; ++i)
|
||||
{
|
||||
rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
|
||||
rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
|
||||
|
||||
hidden_dim_cnt++;
|
||||
}
|
||||
});
|
||||
|
||||
// transform: P dimensions
|
||||
constexpr index_t ndim_p = Ps2RHssMajor::size();
|
||||
|
||||
Dims hidden_dim_id_ps;
|
||||
|
||||
static_for<0, ndim_p, 1>{}([&](auto iDimP) {
|
||||
//
|
||||
index_t hidden_dim_id_p = hidden_dim_cnt++;
|
||||
|
||||
hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
|
||||
|
||||
constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
|
||||
constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
|
||||
|
||||
static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
|
||||
|
||||
constexpr index_t ndim_low = p2RHsMajor.size();
|
||||
|
||||
Dims low_dims;
|
||||
Lengths low_lengths;
|
||||
|
||||
for(index_t i = 0; i < ndim_low; ++i)
|
||||
{
|
||||
index_t rh_major = p2RHsMajor[i];
|
||||
index_t rh_minor = p2RHsMinor[i];
|
||||
low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
|
||||
low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
|
||||
}
|
||||
|
||||
trans(num_tran++) = {coord_transform_enum::merge,
|
||||
MetaData{to_array<index_t, ndim_low>(low_lengths)},
|
||||
NumDim{ndim_low},
|
||||
low_dims,
|
||||
NumDim{1},
|
||||
Dims{hidden_dim_id_p}};
|
||||
});
|
||||
|
||||
constexpr index_t ndim_bottom = ndim_x;
|
||||
|
||||
constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
|
||||
|
||||
constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
|
||||
constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
|
||||
|
||||
constexpr index_t ndim_y = Ys2RHsMajor::size();
|
||||
constexpr index_t ndim_top = ndim_p + ndim_y;
|
||||
|
||||
auto top_dim_ids = hidden_dim_id_ps;
|
||||
|
||||
{
|
||||
for(index_t i = 0; i < ndim_y; ++i)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major[i];
|
||||
index_t rh_minor = ys_to_rhs_minor[i];
|
||||
top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
const auto ps_ys_to_xs_adaptor_encoding =
|
||||
make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
|
||||
|
||||
// descriptor: [y0, y1, ...] to [d]
|
||||
Lengths y_lengths;
|
||||
index_t d_length = 1;
|
||||
|
||||
for(index_t i = 0; i < ndim_y; ++i)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major[i];
|
||||
index_t rh_minor = ys_to_rhs_minor[i];
|
||||
index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
|
||||
y_lengths(i) = y_length;
|
||||
d_length *= y_length;
|
||||
}
|
||||
|
||||
auto tran = make_tuple(coord_transform_enum::unmerge,
|
||||
MetaData{to_array<index_t, ndim_y>(y_lengths)},
|
||||
NumDim{1},
|
||||
Dims{0},
|
||||
NumDim{ndim_y},
|
||||
make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
|
||||
|
||||
const auto ys_to_d_adaptor_encoding = make_tuple(
|
||||
make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
|
||||
|
||||
return make_tuple(ps_ys_to_xs_adaptor_encoding,
|
||||
ys_to_d_adaptor_encoding,
|
||||
d_length,
|
||||
rh_major_minor_to_hidden_ids);
|
||||
}
|
||||
|
||||
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
|
||||
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
|
||||
struct tile_distribution_detail
|
||||
{
|
||||
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
|
||||
to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
#if 0
|
||||
// this returns a constexpr tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
|
||||
constexpr auto adaptor_impl =
|
||||
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
|
||||
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
|
||||
constexpr index_t d_length = adaptor_impl.template at<2>();
|
||||
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor =
|
||||
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_descriptor =
|
||||
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
|
||||
|
||||
//
|
||||
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
|
||||
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
|
||||
|
||||
constexpr auto rh_major_minor_to_hidden_ids =
|
||||
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
|
||||
|
||||
return tile_distribution<
|
||||
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
|
||||
remove_cvref_t<decltype(ys_to_d_descriptor)>,
|
||||
remove_cvref_t<DstrEncode>,
|
||||
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
|
||||
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
|
||||
}
|
||||
#endif
|
||||
|
||||
// this returns a static tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
|
||||
constexpr auto adaptor_impl =
|
||||
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
|
||||
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
|
||||
constexpr index_t d_length = adaptor_impl.template at<2>();
|
||||
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor =
|
||||
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_adaptor =
|
||||
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_descriptor =
|
||||
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});
|
||||
|
||||
//
|
||||
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
|
||||
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
|
||||
|
||||
constexpr auto rh_major_minor_to_hidden_ids =
|
||||
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
|
||||
|
||||
return tile_distribution<
|
||||
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
|
||||
remove_cvref_t<decltype(ys_to_d_descriptor)>,
|
||||
remove_cvref_t<DstrEncode>,
|
||||
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
|
||||
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
|
||||
}
|
||||
|
||||
//***********************************************************************************
|
||||
|
||||
namespace detail {
|
||||
//
|
||||
// slice tensor from x_dim, result in split in y_dim, not p_dim.
|
||||
// We don't support slice cross p_dim (aka, slice different threads)
|
||||
// also, sliced along y_dim need be the first dim of current dim.
|
||||
// Multiply Y dim before sliced dim does not make sense
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y sim need to split into 2
|
||||
// subdime
|
||||
// the P dim in the left is 1, means actually not crossing P
|
||||
//
|
||||
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
|
||||
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
|
||||
{
|
||||
// NOTE: this function need to be called under constexpr context,
|
||||
// due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
|
||||
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
|
||||
|
||||
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
|
||||
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
|
||||
|
||||
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
|
||||
|
||||
constexpr auto x_slice_ends_ = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(x_slice_ends[i] == -1)
|
||||
{
|
||||
// -1 means till the end
|
||||
constexpr auto x_length_ = container_reduce(
|
||||
typename Encoding::HsLengthss{}[i], multiplies<>{}, number<1>{});
|
||||
return x_length_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return x_slice_ends[i];
|
||||
}
|
||||
},
|
||||
number<x_slice_ends.size()>{});
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
|
||||
|
||||
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr auto len_ = x_slice_lengths[i];
|
||||
static_assert(len_ % p_len_over_h[i] == 0,
|
||||
"slice length must be dividable by p_len_over_h");
|
||||
return number<len_ / p_len_over_h[i]>{};
|
||||
},
|
||||
number<x_slice_lengths.size()>{});
|
||||
|
||||
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
|
||||
constexpr auto src_y_dims = src_y_info[number<0>{}];
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
|
||||
|
||||
// This lambda will modify some value outside, so c++ will not treat return value as
|
||||
// constexpr
|
||||
// TODO: ugly
|
||||
auto new_h_lengths = transform_tuples(
|
||||
[&](auto h_len, auto id) {
|
||||
constexpr auto sliced_h = reverse_slice_sequence(
|
||||
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
|
||||
|
||||
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
|
||||
constexpr auto sliced_h_index = sliced_h[number<2>{}];
|
||||
|
||||
// update y_slice_lengths
|
||||
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
|
||||
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
|
||||
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
|
||||
|
||||
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
|
||||
"not sliced at y dim, please check");
|
||||
|
||||
{
|
||||
constexpr auto sliced_y_to_h_lens =
|
||||
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
|
||||
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
|
||||
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
|
||||
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
|
||||
});
|
||||
}
|
||||
// TODO: add validations not across p dim
|
||||
|
||||
// NOTE: this y_origin is for all dims, not only current dim
|
||||
// will later use pick to select target dim
|
||||
constexpr auto y_origin = [&]() {
|
||||
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
|
||||
constexpr auto y_to_h_len =
|
||||
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
|
||||
constexpr auto y_to_h_dims = y_to_h_len.size();
|
||||
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
|
||||
|
||||
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
|
||||
|
||||
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
|
||||
});
|
||||
return y_origin_;
|
||||
}();
|
||||
|
||||
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
|
||||
src_y_prefix_sum[id + 1],
|
||||
1>::type{};
|
||||
|
||||
set_container_subset(
|
||||
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
|
||||
return sliced_h_lens;
|
||||
},
|
||||
typename Encoding::HsLengthss{},
|
||||
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
|
||||
|
||||
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
|
||||
|
||||
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
|
||||
}();
|
||||
|
||||
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
|
||||
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
|
||||
constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
|
||||
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
|
||||
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
|
||||
|
||||
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
|
||||
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
|
||||
|
||||
return make_tuple(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<typename Encoding::RsLengths,
|
||||
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
|
||||
// change the
|
||||
// h_lengths type
|
||||
typename Encoding::Ps2RHssMajor,
|
||||
typename Encoding::Ps2RHssMinor,
|
||||
typename Encoding::Ys2RHsMajor,
|
||||
typename Encoding::Ys2RHsMinor>{}),
|
||||
sliced_y_origins,
|
||||
sliced_y_lengths);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
|
||||
Ys2DDescriptor_,
|
||||
StaticTileDistributionEncoding_,
|
||||
TileDistributionDetail_>& distribution)
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(StaticTileDistributionEncoding_{});
|
||||
printf(", ");
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(distribution.ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
printf("ys_to_d_: ");
|
||||
print(distribution.ys_to_d_);
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
899
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
899
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
@@ -0,0 +1,899 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename RsLengths_, // sequence<...>
|
||||
typename HsLengthss_, // tuple<sequence<...>, ...>
|
||||
typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
|
||||
typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
|
||||
typename Ys2RHsMajor_, // sequence<...>
|
||||
typename Ys2RHsMinor_> // sequence<...>
|
||||
struct tile_distribution_encoding
|
||||
{
|
||||
using RsLengths = remove_cvref_t<RsLengths_>;
|
||||
using HsLengthss = remove_cvref_t<HsLengthss_>;
|
||||
using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
|
||||
using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
|
||||
using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
|
||||
using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
|
||||
|
||||
static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
|
||||
static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
|
||||
|
||||
static constexpr index_t NDimX = HsLengthss::size();
|
||||
static constexpr index_t NDimP = Ps2RHssMajor::size();
|
||||
static constexpr index_t NDimY = Ys2RHsMajor::size();
|
||||
static constexpr index_t NDimR = RsLengths::size();
|
||||
|
||||
// FIXME: move into detail
|
||||
static constexpr auto rs_lengths_ = RsLengths{};
|
||||
static constexpr auto hs_lengthss_ = HsLengthss{};
|
||||
static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
|
||||
static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
|
||||
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
|
||||
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
|
||||
|
||||
#if !CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
|
||||
"do not support Y dim pointed to R dim");
|
||||
#endif
|
||||
|
||||
// redundant but useful info
|
||||
// TODO: really bad code, should be over-hauled
|
||||
struct detail
|
||||
{
|
||||
// ndim_rh_major_, ndim_span_mainor_
|
||||
static constexpr index_t ndim_rh_major_ = NDimX + 1;
|
||||
static constexpr index_t ndim_span_major_ = NDimX;
|
||||
|
||||
// ndims_rhs_minor_[ndim_rh_major_]
|
||||
static constexpr auto ndims_rhs_minor_ = generate_array(
|
||||
[](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return rs_lengths_.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return hs_lengthss_[i - number<1>{}].size();
|
||||
}
|
||||
},
|
||||
number<ndim_rh_major_>{});
|
||||
|
||||
// max_ndim_rh_minor_
|
||||
static constexpr index_t max_ndim_rh_minor_ =
|
||||
container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_lengthss_ =
|
||||
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
|
||||
|
||||
// ys_lengths_
|
||||
static constexpr auto ys_lengths_ = [] {
|
||||
array<index_t, NDimY> ys_lengths_tmp{-1};
|
||||
|
||||
for(index_t i = 0; i < NDimY; i++)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major_[i];
|
||||
index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
|
||||
}
|
||||
|
||||
return ys_lengths_tmp;
|
||||
}();
|
||||
|
||||
// rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_major_minor_to_ys_ = [] {
|
||||
array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
|
||||
});
|
||||
|
||||
return rhs_major_minor_to_ys_tmp;
|
||||
}();
|
||||
|
||||
// ndims_span_minor_[NDimY]
|
||||
static constexpr auto ndims_span_minor_ = [] {
|
||||
array<index_t, NDimX> ndims_span_minor{0};
|
||||
|
||||
for(index_t i = 0; i < NDimY; i++)
|
||||
{
|
||||
const index_t span_major = ys_to_rhs_major_[i] - 1;
|
||||
|
||||
ndims_span_minor(span_major)++;
|
||||
}
|
||||
|
||||
return ndims_span_minor;
|
||||
}();
|
||||
|
||||
// max_ndim_span_minor_
|
||||
static constexpr index_t max_ndim_span_minor_ =
|
||||
container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
|
||||
array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
|
||||
{-1}};
|
||||
|
||||
static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
|
||||
constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
|
||||
|
||||
index_t cnt_ndim_span_minor = 0;
|
||||
|
||||
static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
|
||||
constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
|
||||
|
||||
if(idim_y >= 0)
|
||||
{
|
||||
rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
|
||||
|
||||
cnt_ndim_span_minor++;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return rhs_major_minor_to_span_minor;
|
||||
}();
|
||||
|
||||
// ys_to_span_major_[NDimY]
|
||||
static constexpr auto ys_to_span_major_ =
|
||||
generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
|
||||
|
||||
// ys_to_span_minor_[NDimY]
|
||||
static constexpr auto ys_to_span_minor_ = generate_array(
|
||||
[](auto i) {
|
||||
return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
|
||||
static constexpr auto distributed_spans_lengthss_ = [] {
|
||||
array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
|
||||
distributed_spans_lengthss{{-1}};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
const index_t rh_major = ys_to_rhs_major_[i];
|
||||
const index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
|
||||
|
||||
const index_t span_major = rh_major - 1;
|
||||
const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
|
||||
|
||||
distributed_spans_lengthss(span_major)(span_minor) = h_length;
|
||||
});
|
||||
|
||||
return distributed_spans_lengthss;
|
||||
}();
|
||||
|
||||
// ndims_distributed_spans_minor_[ndim_span_major_]
|
||||
static constexpr auto ndims_distributed_spans_minor_ = [] {
|
||||
array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
const index_t span_major = ys_to_rhs_major_[i] - 1;
|
||||
|
||||
ndims_distributed_spans_minor(span_major)++;
|
||||
});
|
||||
|
||||
return ndims_distributed_spans_minor;
|
||||
}();
|
||||
|
||||
// does_p_own_r_[NDimP][NDimR]
|
||||
static constexpr auto does_p_own_r_ = [] {
|
||||
if constexpr(NDimR > 0)
|
||||
{
|
||||
array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
|
||||
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
does_p_own_r(idim_p)(rh_minor) = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return does_p_own_r;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<array<bool, NDimR>, NDimP>{};
|
||||
}
|
||||
}();
|
||||
|
||||
// ps_over_rs_derivative_[NDimP][NDimR]
|
||||
static constexpr auto ps_over_rs_derivative_ = [] {
|
||||
if constexpr(NDimR > 0)
|
||||
{
|
||||
array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
index_t p_over_rh_derivative = 1;
|
||||
|
||||
static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
|
||||
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
|
||||
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
|
||||
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
|
||||
}
|
||||
|
||||
p_over_rh_derivative *= rh_length;
|
||||
});
|
||||
});
|
||||
|
||||
return ps_over_rs_derivative;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<array<index_t, NDimR>, NDimP>{};
|
||||
}
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
|
||||
{
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
|
||||
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
constexpr index_t size_ = HsLengthss{}[i].size();
|
||||
return number<size_>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
return uniformed_h_dim_lengths;
|
||||
}
|
||||
|
||||
// note: this function only count the p dim length along h, not r
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
|
||||
{
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
|
||||
// Y P Y Y P Y P Y
|
||||
// | | |
|
||||
// v v v
|
||||
// return : seq<4, 2 * 4> => seq<4, 8>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto p_len_ = [&]() {
|
||||
array<index_t, NDimX> len_{1};
|
||||
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
|
||||
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
|
||||
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
|
||||
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
|
||||
{
|
||||
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
|
||||
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
|
||||
len_[idim_x_] *= h_length_;
|
||||
}
|
||||
});
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
|
||||
return p_len_over_h_seq_;
|
||||
}
|
||||
|
||||
//
|
||||
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
|
||||
// => return seq<1, 3, 5>
|
||||
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
|
||||
// => return seq<0, 2, 3>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
|
||||
{
|
||||
constexpr auto uniformed_rh_dim_lengths =
|
||||
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
|
||||
|
||||
return uniformed_rh_dim_lengths;
|
||||
}
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
|
||||
|
||||
return h_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
|
||||
|
||||
return rh_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
|
||||
{
|
||||
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto all_ps_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
uniformed_ps_to_rhss_major_,
|
||||
uniformed_ps_to_rhss_minor_);
|
||||
|
||||
return all_ps_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
// TODO: Y can't point to R
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor - NDimR;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple of seq
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
|
||||
{
|
||||
constexpr auto masks_ = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto size_ = HsLengthss{}[i].size();
|
||||
constexpr auto current_y_to_h_mask_ = [&]() {
|
||||
array<index_t, size_> m_{0};
|
||||
// TODO: we loop over all y for each h dim
|
||||
for(auto j = 0; j < NDimY; j++)
|
||||
{
|
||||
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
|
||||
{
|
||||
m_[Ys2RHsMinor{}[j]] = 1;
|
||||
}
|
||||
}
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(current_y_to_h_mask_, size_);
|
||||
},
|
||||
number<NDimX>{});
|
||||
return masks_;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
{
|
||||
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
|
||||
|
||||
constexpr auto sorted_dims = typename sorted_idx::type{};
|
||||
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
|
||||
|
||||
constexpr auto sorted_histogram =
|
||||
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
|
||||
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
|
||||
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
// Note here y_to_h does not count R dim!
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename encoding, typename shuffle>
|
||||
class tile_distribution_encoding_shuffle;
|
||||
template <typename encoding, index_t... shuffle>
|
||||
class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
|
||||
{
|
||||
template <typename Ys2RHs>
|
||||
using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
|
||||
|
||||
public:
|
||||
using type = tile_distribution_encoding<typename encoding::RsLengths,
|
||||
typename encoding::HsLengthss,
|
||||
typename encoding::Ps2RHssMajor,
|
||||
typename encoding::Ps2RHssMinor,
|
||||
shuffled<typename encoding::Ys2RHsMajor>,
|
||||
shuffled<typename encoding::Ys2RHsMinor>>;
|
||||
};
|
||||
template <typename encoding, typename shuffle>
|
||||
using tile_distribution_encoding_shuffle_t =
|
||||
typename tile_distribution_encoding_shuffle<encoding, shuffle>::type;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename OuterDstr, typename InnerDstr>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
|
||||
{
|
||||
static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
|
||||
|
||||
constexpr index_t NDimHMajor = OuterDstr::NDimX;
|
||||
|
||||
using RsLengths =
|
||||
sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
|
||||
|
||||
constexpr auto hs_lengthss = generate_tuple(
|
||||
[&](auto i) {
|
||||
return merge_sequences(typename OuterDstr::HsLengthss{}[i],
|
||||
typename InnerDstr::HsLengthss{}[i]);
|
||||
},
|
||||
number<NDimHMajor>{});
|
||||
|
||||
//
|
||||
constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
|
||||
array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
|
||||
|
||||
// R dimension
|
||||
rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
|
||||
|
||||
// Hs dimensions
|
||||
static_for<0, NDimHMajor, 1>{}([&](auto i) {
|
||||
rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
|
||||
});
|
||||
|
||||
return rhs_major_2_ndim_outer_rhs_minor_;
|
||||
}();
|
||||
|
||||
// Ps2RHssMinor
|
||||
constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
|
||||
[&](auto p) {
|
||||
constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
|
||||
constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
|
||||
|
||||
constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
|
||||
|
||||
constexpr auto updated_inner_p_2_rhss_minor = [&]() {
|
||||
array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
|
||||
|
||||
for(index_t i = 0; i < ndim_tmp; i++)
|
||||
{
|
||||
index_t rh_major = inner_p_2_rhss_major[i];
|
||||
|
||||
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
|
||||
|
||||
updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
|
||||
}
|
||||
|
||||
return updated_inner_p_2_rhss_minor_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
|
||||
},
|
||||
number<InnerDstr::NDimP>{});
|
||||
|
||||
// Ys2RHsMinor
|
||||
constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
|
||||
constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
|
||||
constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
|
||||
|
||||
constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
|
||||
|
||||
constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
|
||||
array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
|
||||
|
||||
for(index_t i = 0; i < ndim_tmp; i++)
|
||||
{
|
||||
index_t rh_major = inner_ys_2_rhs_major[i];
|
||||
|
||||
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
|
||||
|
||||
updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
|
||||
}
|
||||
|
||||
return updated_inner_ys_2_rhs_minor__;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto ps_2_rhss_major =
|
||||
container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
|
||||
|
||||
constexpr auto ps_2_rhss_minor =
|
||||
container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
|
||||
|
||||
//
|
||||
constexpr auto ys_2_rhs_major =
|
||||
merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
|
||||
|
||||
constexpr auto ys_2_rhs_minor =
|
||||
merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
|
||||
|
||||
return tile_distribution_encoding<RsLengths,
|
||||
remove_cvref_t<decltype(hs_lengthss)>,
|
||||
remove_cvref_t<decltype(ps_2_rhss_major)>,
|
||||
remove_cvref_t<decltype(ps_2_rhss_minor)>,
|
||||
remove_cvref_t<decltype(ys_2_rhs_major)>,
|
||||
remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
|
||||
}
|
||||
|
||||
template <typename InDstr, index_t... InReduceDimXs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// FIXME: increase if fail
|
||||
constexpr index_t max_ndim_r_out = 20;
|
||||
constexpr index_t max_ndim_y_out = 20;
|
||||
|
||||
//
|
||||
constexpr index_t ndim_p = InDstr::NDimP;
|
||||
constexpr index_t ndim_x_in = InDstr::NDimX;
|
||||
constexpr index_t ndim_y_in = InDstr::NDimY;
|
||||
constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
|
||||
constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
|
||||
constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
|
||||
|
||||
// ndims_ps_low
|
||||
constexpr auto ndims_ps_low = generate_array(
|
||||
[&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
|
||||
|
||||
// is_rh_major_in_for_reduce
|
||||
array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
|
||||
|
||||
for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
|
||||
{
|
||||
index_t rh_major = reduce_dim_xs_in[i] + 1;
|
||||
|
||||
is_rh_major_in_for_reduce(rh_major) = true;
|
||||
}
|
||||
|
||||
// is_y_in_for_reduce
|
||||
array<bool, ndim_y_in> is_y_in_for_reduce{false};
|
||||
|
||||
for(index_t i = 0; i < ndim_y_in; i++)
|
||||
{
|
||||
index_t rh_major = InDstr::ys_to_rhs_major_[i];
|
||||
|
||||
if(is_rh_major_in_for_reduce[rh_major])
|
||||
{
|
||||
is_y_in_for_reduce(i) = true;
|
||||
}
|
||||
}
|
||||
|
||||
// is_rh_minor_in_for_y_reduce
|
||||
array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
|
||||
|
||||
static_for<0, ndim_y_in, 1>{}([&](auto i) {
|
||||
index_t rh_major = InDstr::ys_to_rhs_major_[i];
|
||||
index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
|
||||
|
||||
if(is_y_in_for_reduce[i])
|
||||
{
|
||||
is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
|
||||
}
|
||||
});
|
||||
|
||||
// in2out_rh_major
|
||||
array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
|
||||
index_t cnt_ndim_rh_major_out = 0;
|
||||
|
||||
for(index_t i = 0; i < ndim_rh_major_in; i++)
|
||||
{
|
||||
if(is_rh_major_in_for_reduce[i])
|
||||
{
|
||||
in2out_rh_major(i) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
in2out_rh_major(i) = cnt_ndim_rh_major_out;
|
||||
|
||||
cnt_ndim_rh_major_out++;
|
||||
}
|
||||
}
|
||||
|
||||
// rs_lengths_out, in2out_rh_minor
|
||||
array<index_t, max_ndim_r_out> rs_lengths_out{-1};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
|
||||
|
||||
// loop over input R dim
|
||||
for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
|
||||
{
|
||||
// rs_lengths_out
|
||||
rs_lengths_out(i) = InDstr::rs_lengths_[i];
|
||||
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(0)(i) = i;
|
||||
}
|
||||
|
||||
// loop over input H Dim
|
||||
index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
|
||||
|
||||
static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
|
||||
constexpr auto h_major_in = rh_major_in - I1;
|
||||
|
||||
constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
|
||||
|
||||
if(is_rh_major_in_for_reduce[rh_major_in])
|
||||
{
|
||||
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
|
||||
{
|
||||
if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
|
||||
{
|
||||
// rs_lengths_out
|
||||
rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
|
||||
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
|
||||
|
||||
cnt_ndim_r_out++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
|
||||
{
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// ndim_r_out
|
||||
const index_t ndim_r_out = cnt_ndim_r_out;
|
||||
|
||||
// ndims_hs_minor_out, hs_lengthss_out
|
||||
array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
|
||||
|
||||
index_t cnt_ndim_x_out = 0;
|
||||
|
||||
static_for<0, ndim_x_in, 1>{}([&](auto i) {
|
||||
if(not is_rh_major_in_for_reduce[i + I1])
|
||||
{
|
||||
// ndims_hs_minor_out
|
||||
ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
|
||||
|
||||
// hs_lengthss_out
|
||||
static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
|
||||
[&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
|
||||
|
||||
cnt_ndim_x_out++;
|
||||
}
|
||||
});
|
||||
|
||||
// ps_to_rhss_major_out, ps_to_rhss_minor_out
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
|
||||
|
||||
static_for<0, ndim_p, 1>{}([&](auto idim_p) {
|
||||
static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
|
||||
index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
|
||||
index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
|
||||
ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
|
||||
});
|
||||
});
|
||||
|
||||
// ys_to_rhs_major_out, ys_to_rhs_minor_out
|
||||
array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
|
||||
array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
|
||||
|
||||
index_t cnt_ndim_y_out = 0;
|
||||
|
||||
static_for<0, ndim_y_in, 1>{}([&](auto i) {
|
||||
if(not is_y_in_for_reduce[i])
|
||||
{
|
||||
index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
|
||||
index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
|
||||
|
||||
ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
|
||||
ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
|
||||
|
||||
cnt_ndim_y_out++;
|
||||
}
|
||||
});
|
||||
|
||||
// ndim_y_out
|
||||
const index_t ndim_y_out = cnt_ndim_y_out;
|
||||
|
||||
//
|
||||
return make_tuple(ndim_x_out,
|
||||
ndim_p,
|
||||
ndim_y_out,
|
||||
ndim_r_out,
|
||||
ndims_hs_minor_out,
|
||||
ndims_ps_low,
|
||||
rs_lengths_out,
|
||||
hs_lengthss_out,
|
||||
ps_to_rhss_major_out,
|
||||
ps_to_rhss_minor_out,
|
||||
ys_to_rhs_major_out,
|
||||
ys_to_rhs_minor_out);
|
||||
}
|
||||
|
||||
template <typename InDstr, index_t... InReduceDimXs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
|
||||
{
|
||||
constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
|
||||
|
||||
constexpr index_t ndim_x = impl.template at<0>();
|
||||
constexpr index_t ndim_p = impl.template at<1>();
|
||||
constexpr index_t ndim_y = impl.template at<2>();
|
||||
constexpr index_t ndim_r = impl.template at<3>();
|
||||
constexpr auto ndims_hs_minor = impl.template at<4>();
|
||||
constexpr auto ndims_ps_low = impl.template at<5>();
|
||||
constexpr auto rs_lengths_impl = impl.template at<6>();
|
||||
constexpr auto hs_lengthss_impl = impl.template at<7>();
|
||||
constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
|
||||
constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
|
||||
constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
|
||||
constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
|
||||
|
||||
constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
|
||||
constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
|
||||
constexpr auto ps_to_rhss_major =
|
||||
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
|
||||
constexpr auto ps_to_rhss_minor =
|
||||
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
|
||||
constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
|
||||
constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
|
||||
|
||||
return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
|
||||
remove_cvref_t<decltype(hs_lengthss)>,
|
||||
remove_cvref_t<decltype(ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution_encoding::detail
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
print(const typename tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>::detail& detail_obj)
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("ndim_span_major_: ");
|
||||
print(detail_obj.ndim_span_major_);
|
||||
printf(", ");
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(detail_obj.ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(detail_obj.max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
printf("rhs_lengthss_: ");
|
||||
print(detail_obj.rhs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ys_lengths_: ");
|
||||
print(detail_obj.ys_lengths_);
|
||||
printf(", ");
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(detail_obj.rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
printf("ndims_span_minor_: ");
|
||||
print(detail_obj.ndims_span_minor_);
|
||||
printf(", ");
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(detail_obj.max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_major_: ");
|
||||
print(detail_obj.ys_to_span_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(detail_obj.ys_to_span_minor_);
|
||||
printf(", ");
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(detail_obj.distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(detail_obj.ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(detail_obj.ps_over_rs_derivative_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Free print function for tile_distribution_encoding
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>& encoding)
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
|
||||
printf("rs_lengths_: ");
|
||||
print(encoding.rs_lengths_);
|
||||
printf(", ");
|
||||
printf("hs_lengthss_: ");
|
||||
print(encoding.hs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(encoding.ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(encoding.ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(encoding.ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(encoding.ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
396
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
396
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
@@ -0,0 +1,396 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: support tensors with different distribution
|
||||
template <typename InOutElementFunc,
|
||||
typename... InOutDstrTensors,
|
||||
typename = std::enable_if_t<std::conjunction_v<
|
||||
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
|
||||
InOutDstrTensors&... inout_dstr_tensors)
|
||||
{
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
__type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}(
|
||||
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
|
||||
}
|
||||
|
||||
template <typename InElementFunc,
|
||||
typename... InTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InTensor&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
__type_pack_element<0, InTensor...>::get_thread_buffer_size();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
out_dstr_tensor.get_thread_buffer()(i) =
|
||||
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls tile_elementwise_inout with unpacked tuple elements.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple, size_t... I>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t,
|
||||
std::index_sequence<I...>)
|
||||
{
|
||||
return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls the overloaded function, passing an index sequence.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t)
|
||||
{
|
||||
static constexpr auto size = Tuple::size();
|
||||
return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
|
||||
}
|
||||
|
||||
template <typename DstrTensors, typename T>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&value](auto& x) {
|
||||
x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
|
||||
},
|
||||
dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
|
||||
{
|
||||
}
|
||||
|
||||
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
|
||||
// sub-dword tensor...
|
||||
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
|
||||
CK_TILE_DEVICE void
|
||||
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
|
||||
{
|
||||
using elem_type = typename DstrTensors::DataType;
|
||||
constexpr index_t elem_size = sizeof(elem_type);
|
||||
|
||||
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
|
||||
|
||||
// # bytes per write = 4
|
||||
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
|
||||
{
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
auto& buffer = dstr_tensor.get_thread_buffer();
|
||||
|
||||
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
|
||||
if constexpr(elem_size == 1)
|
||||
{
|
||||
// # elements per write = 4
|
||||
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
|
||||
|
||||
buffer[i_write * 4 + 0] = values.x;
|
||||
buffer[i_write * 4 + 1] = values.y;
|
||||
buffer[i_write * 4 + 2] = values.z;
|
||||
buffer[i_write * 4 + 3] = values.w;
|
||||
}
|
||||
else if constexpr(elem_size == 2)
|
||||
{
|
||||
// # elements per write = 2
|
||||
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
|
||||
|
||||
buffer[i_write * 2 + 0] = values.x;
|
||||
buffer[i_write * 2 + 1] = values.y;
|
||||
}
|
||||
else if constexpr(elem_size == 4)
|
||||
{
|
||||
// # elements per write = 1
|
||||
constexpr elem_type value = 0;
|
||||
|
||||
buffer[i_write] = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "type not supported");
|
||||
}
|
||||
});
|
||||
#else
|
||||
using dvec_t = array<index_t, tensor_bytes / 4>;
|
||||
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
|
||||
for(auto i = 0; i < tensor.size(); i++)
|
||||
tensor.get(i) = v;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
|
||||
dstr_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t v>
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename DstrTensors>
|
||||
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
|
||||
{
|
||||
set_tile(dstr_tensor, 0);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// TODO: this is ugly
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx94__) || defined(__gfx12__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 4 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
// __builtin_amdgcn_cvt_pk_fp8_f32() this builtin requires the old value, and
|
||||
// will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
|
||||
// so we prepare an uninitialized variable purposely, and turn off the warning
|
||||
int dummy_old;
|
||||
static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
|
||||
uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
|
||||
dummy_old,
|
||||
false); // false -> WORD0
|
||||
|
||||
uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
|
||||
x,
|
||||
true); // true -> WORD1
|
||||
|
||||
using vec_t = array<OutDataType, 4>;
|
||||
|
||||
vec_t d = bit_cast<vec_t>(y);
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
|
||||
});
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pkrtz_fp16_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 2 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
// TODO: this is rtz cvt, need be very careful
|
||||
for(index_t i = 0; i < thread_buffer_size_pk; i++)
|
||||
{
|
||||
auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
|
||||
in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
|
||||
}
|
||||
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
// This API is designed to help compiler to identify pairs of f32 -> fp16/bf16 cast and use
|
||||
// cvt_pk instruction when possible
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 2 == 0);
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
using f16x2_t = std::conditional_t<std::is_same_v<OutDataType, fp16_t>, fp16x2_t, bf16x2_t>;
|
||||
for(index_t i = 0; i < thread_buffer_size / 2; i++)
|
||||
{
|
||||
auto o = type_convert<f16x2_t>(fp32x2_t{
|
||||
in_dstr_tensors.get_thread_buffer()[2 * i + 0],
|
||||
in_dstr_tensors.get_thread_buffer()[2 * i + 1],
|
||||
});
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
|
||||
}
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
// this function assume either src or dst (or both) date type is under 1 dword
|
||||
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
using i_type = remove_cvref_t<typename InTensor::DataType>;
|
||||
using o_type = remove_cvref_t<OutDataType>;
|
||||
constexpr index_t i_elem_bytes = sizeof(i_type);
|
||||
constexpr index_t o_elem_bytes = sizeof(o_type);
|
||||
static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
|
||||
|
||||
constexpr index_t bulk_size =
|
||||
(i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
|
||||
static_assert(bulk_size != 0);
|
||||
|
||||
using o_bulk_type =
|
||||
std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
|
||||
constexpr index_t iters = thread_buffer_size / bulk_size;
|
||||
constexpr index_t rems = thread_buffer_size % bulk_size;
|
||||
|
||||
// cast the sequence per-bulk
|
||||
static_for<0, iters, 1>{}([&](auto i) {
|
||||
union bulk_wrapper
|
||||
{
|
||||
o_bulk_type bulk{};
|
||||
o_type data[bulk_size];
|
||||
} o_bulk;
|
||||
|
||||
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
|
||||
static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
|
||||
o_bulk.data[ib.value] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer()
|
||||
.template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
|
||||
});
|
||||
|
||||
// TODO: fixme, should use above!
|
||||
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
|
||||
// o_bulk.data[0] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
|
||||
// o_bulk.data[1] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
|
||||
});
|
||||
|
||||
static_for<0, rems, 1>{}([&](auto r) {
|
||||
// TODO: introducing local scratch pad?
|
||||
auto idx = number<iters * bulk_size + r>{};
|
||||
out_dstr_tensor.get_thread_buffer().at(idx) =
|
||||
static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
#endif
|
||||
} // namespace impl
|
||||
|
||||
template <typename DstType, typename SrcTensor>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
|
||||
#if CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
else if constexpr(std::is_same_v<DstType, fp16_t> &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 2 == 0))
|
||||
return impl::cast_tile_pkrtz_fp16_fp32<DstType, SrcTensor>(src_tensor);
|
||||
#endif
|
||||
#if 0 // currently it causes extra spills in qr_async_vr pipeline of fmha_fwd
|
||||
else if constexpr((std::is_same_v<DstType, fp16_t> || std::is_same_v<DstType, bf16_t>) &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 2 == 0))
|
||||
return impl::cast_tile_pk_fp16bf16_fp32<DstType, SrcTensor>(src_tensor);
|
||||
#endif
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
|
||||
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
|
||||
#endif
|
||||
else
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
|
||||
}
|
||||
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InOutElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
}
|
||||
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1366
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
1366
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
File diff suppressed because it is too large
Load Diff
1521
include/ck_tile/core/tensor/tile_window.hpp
Normal file
1521
include/ck_tile/core/tensor/tile_window.hpp
Normal file
File diff suppressed because it is too large
Load Diff
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
* @note This class does not provide any functions to read or modify device memory.
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
*/
|
||||
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_base
|
||||
{
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
// Delegate to child if it implements extra logic
|
||||
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
|
||||
}
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
// Delegate to child if it implements extra movement logic
|
||||
static_cast<TileWindowType_*>(this)->move_extended(step);
|
||||
}
|
||||
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
};
|
||||
|
||||
template <typename TileWindowType_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_>
|
||||
struct tile_window_with_tile_dstr_base
|
||||
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
|
||||
{
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord = decltype(make_tensor_coordinate(
|
||||
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
|
||||
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
struct Traits
|
||||
{
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
|
||||
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
using vector_t =
|
||||
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
};
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
1122
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
1122
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
File diff suppressed because it is too large
Load Diff
61
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
61
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TileWindow_>
|
||||
CK_TILE_DEVICE void move_tile_window(TileWindow_& window,
|
||||
const typename TileWindow_::BottomTensorIndex& step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set m0 value for gfx9 serious
|
||||
template <typename LdsTileWindow_>
|
||||
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
return make_tuple(m0_init_value, size_per_issue);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
220
include/ck_tile/core/tensor/transpose_tile.hpp
Normal file
220
include/ck_tile/core/tensor/transpose_tile.hpp
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
const InTensor& in_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
|
||||
"Data type for InTensor and OutTensor must be the same!");
|
||||
|
||||
using DataType = typename InTensor::DataType;
|
||||
|
||||
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
|
||||
// In swapped Hs case <Y,X> -> <X,Y> tile
|
||||
// we have same rh_major, but reversed rh_minor!
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; });
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
constexpr index_t y_dim_vec_in = NDimY - 1;
|
||||
constexpr index_t y_dim_vec_out = 0;
|
||||
|
||||
// vector lengths
|
||||
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
|
||||
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
|
||||
|
||||
// # of vectors
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) {
|
||||
if constexpr(vec_length_in == 1)
|
||||
return 1;
|
||||
else
|
||||
return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
|
||||
using SFC_Y = space_filling_curve<decltype(y_lengths),
|
||||
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
|
||||
decltype(scalars_per_access)>;
|
||||
|
||||
constexpr index_t num_access = SFC_Y::get_num_of_access();
|
||||
|
||||
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
|
||||
|
||||
if constexpr(num_vec_in == 1 || num_vec_out == 1)
|
||||
{
|
||||
// loop over SFC
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
constexpr auto idx_y_in =
|
||||
generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
constexpr auto idx_y_out_tmp =
|
||||
generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
if constexpr(vec_length_in == 1)
|
||||
{
|
||||
|
||||
out_tensor.get_thread_buffer()[number<out_offset>{}] =
|
||||
in_tensor.get_thread_buffer()[number<in_offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
using Vec = array<DataType, vec_length_in>;
|
||||
out_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<out_offset / vec_length_in>{}) =
|
||||
in_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
using InDataType = typename InTensor::DataType;
|
||||
using OutDataType = typename OutTensor::DataType;
|
||||
|
||||
using InTileDistr = typename InTensor::StaticTileDistribution;
|
||||
using OutTileDistr = typename OutTensor::StaticTileDistribution;
|
||||
|
||||
using InDstrEncode = typename InTileDistr::DstrEncode;
|
||||
using OutDstrEncode = typename OutTileDistr::DstrEncode;
|
||||
|
||||
using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
|
||||
using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
|
||||
|
||||
// Ys:
|
||||
constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
|
||||
constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
|
||||
|
||||
// type convert
|
||||
const auto in_tmp = [&]() {
|
||||
if constexpr(std::is_same_v<OutDataType, InDataType>)
|
||||
{
|
||||
return in;
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
|
||||
}
|
||||
}();
|
||||
|
||||
// Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
|
||||
// we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
|
||||
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
|
||||
InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
|
||||
InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
|
||||
in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
|
||||
// Any condition on Ps ??
|
||||
// InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
|
||||
// InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
|
||||
{
|
||||
detail::transpose_tile2d_impl_in_thread(out, in_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Provided tensors could not be transposed!");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
105
include/ck_tile/core/tensor/update_tile.hpp
Normal file
105
include/ck_tile/core/tensor/update_tile.hpp
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto update_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
|
||||
{
|
||||
static_assert(__has_builtin(__builtin_bit_cast), "");
|
||||
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
|
||||
|
||||
return __builtin_bit_cast(Y, x);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
161
include/ck_tile/core/utility/debug.hpp
Normal file
161
include/ck_tile/core/utility/debug.hpp
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct thread_buffer;
|
||||
|
||||
// Usage example: CK_PRINTF<float>{}(tensor);
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF;
|
||||
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINTF<ConvertTo,
|
||||
str_literal<FMTChars...>,
|
||||
str_literal<PREFIXChars...>,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return std::make_tuple(make_str_literal("%8.3f"), T{});
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return std::make_tuple(make_str_literal("%5d"), T{});
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return std::make_tuple(make_str_literal("%5u"), T{});
|
||||
else if constexpr(sizeof(T) == 1)
|
||||
return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
|
||||
else if constexpr(sizeof(T) == 2)
|
||||
return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
|
||||
else if constexpr(sizeof(T) == 4)
|
||||
return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
|
||||
else
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
template <typename T>
|
||||
using default_format_t =
|
||||
std::remove_reference_t<decltype(std::get<0>(default_format_and_type<T>()))>;
|
||||
template <typename T>
|
||||
using default_type_t =
|
||||
std::remove_reference_t<decltype(std::get<1>(default_format_and_type<T>()))>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>,
|
||||
Args&&... args) const
|
||||
{
|
||||
using FMT1 = std::
|
||||
conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data,
|
||||
get_thread_id(),
|
||||
N,
|
||||
args...,
|
||||
bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(
|
||||
buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... TS, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
|
||||
Args&&... args) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
print(std::forward<T>(x));
|
||||
}
|
||||
template <typename... Ts>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
|
||||
{
|
||||
using base_t = CK_PRINTF<Ts...>;
|
||||
|
||||
template <typename T, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf, std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* RAII struct which inserts start/end markers into the generated assembly.
|
||||
*
|
||||
* Usage:
|
||||
* - Create an `AsmScopeMarker` object at the beginning of a scope or code block.
|
||||
* - Its constructor will emit a "CK_ASM_SCOPE_START" marker into the assembly.
|
||||
* - When the object goes out of scope (end of block, return, exception, etc.),
|
||||
* the destructor will emit a "CK_ASM_SCOPE_END" marker.
|
||||
*
|
||||
* Example:
|
||||
* {
|
||||
* [[maybe_unused]] AsmScopeMarker marker; // Emits CK_ASM_SCOPE_START
|
||||
* // ... code you want to delimit in assembly ...
|
||||
* } // marker goes out of scope → Emits CK_ASM_SCOPE_END
|
||||
*
|
||||
*/
|
||||
struct AsmScopeMarker
|
||||
{
|
||||
// in some future version of clang we might be able to use string_view to customize
|
||||
CK_TILE_HOST_DEVICE AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_START"); }
|
||||
CK_TILE_HOST_DEVICE ~AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_END"); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
220
include/ck_tile/core/utility/env.hpp
Normal file
220
include/ck_tile/core/utility/env.hpp
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename... Args>
|
||||
void CK_TILE_ERROR(Args&&... args) noexcept
|
||||
{
|
||||
std::ostringstream oss;
|
||||
(oss << ... << args);
|
||||
std::cerr << "[CK_TILE_ERROR] " << oss.str() << std::endl;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void CK_TILE_INFO(Args&&... args) noexcept
|
||||
{
|
||||
std::ostringstream oss;
|
||||
(oss << ... << args);
|
||||
std::cout << "[CK_TILE_INFO] " << oss.str() << std::endl;
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <size_t N>
|
||||
bool is_any_of(const char* const (&names)[N], const std::string& str)
|
||||
{
|
||||
return std::any_of(std::begin(names), std::end(names), [&](const char* inner_str) {
|
||||
return str == inner_str;
|
||||
});
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ParseEnvVal
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct ParseEnvVal<bool>
|
||||
{
|
||||
static bool parse_env_var_value(const char* vp)
|
||||
{
|
||||
std::string value_env_str{vp};
|
||||
|
||||
for(auto& c : value_env_str)
|
||||
{
|
||||
if(std::isalpha(c) != 0)
|
||||
{
|
||||
c = std::tolower(static_cast<unsigned char>(c));
|
||||
}
|
||||
}
|
||||
|
||||
if(is_any_of(enabled_names, value_env_str))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(is_any_of(disabled_names, value_env_str))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Invalid value for env variable");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"};
|
||||
static constexpr const char* disabled_names[] = {
|
||||
"disable", "disabled", "0", "no", "off", "false"};
|
||||
};
|
||||
|
||||
// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
|
||||
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
|
||||
template <>
|
||||
struct ParseEnvVal<uint64_t>
|
||||
{
|
||||
static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ParseEnvVal<std::string>
|
||||
{
|
||||
static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct EnvVar
|
||||
{
|
||||
private:
|
||||
T value{};
|
||||
bool is_unset = true;
|
||||
|
||||
public:
|
||||
const T& GetValue() const { return value; }
|
||||
|
||||
bool IsUnset() const { return is_unset; }
|
||||
|
||||
void Unset() { is_unset = true; }
|
||||
|
||||
void UpdateValue(const T& val)
|
||||
{
|
||||
is_unset = false;
|
||||
value = val;
|
||||
}
|
||||
|
||||
explicit EnvVar(const char* const name, const T& def_val)
|
||||
{
|
||||
// NOLINTNEXTLINE (concurrency-mt-unsafe)
|
||||
const char* vp = std::getenv(name);
|
||||
if(vp != nullptr) // a value was provided
|
||||
{
|
||||
is_unset = false;
|
||||
value = ParseEnvVal<T>::parse_env_var_value(vp);
|
||||
}
|
||||
else // no value provided, use default value
|
||||
{
|
||||
value = def_val;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
// Static inside function hides the variable and provides
|
||||
// thread-safety/locking
|
||||
// Used in global namespace
|
||||
#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val) \
|
||||
namespace ck_tile::env { \
|
||||
struct name \
|
||||
{ \
|
||||
static_assert(std::is_same_v<name, ::ck_tile::env::name>, \
|
||||
"CK_TILE_DECLARE_ENV* must be used in the global namespace"); \
|
||||
using value_type = type; \
|
||||
static ck_tile::internal::EnvVar<type>& Ref() \
|
||||
{ \
|
||||
static ck_tile::internal::EnvVar<type> var{#name, default_val}; \
|
||||
return var; \
|
||||
} \
|
||||
}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "")
|
||||
|
||||
#define CK_TILE_ENV(name) \
|
||||
ck_tile::env::name {}
|
||||
|
||||
template <class EnvVar>
|
||||
inline const std::string& EnvGetString(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsEnabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsDisabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline uint64_t EnvValue(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsUnset(EnvVar)
|
||||
{
|
||||
return EnvVar::Ref().IsUnset();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
void EnvUnset(EnvVar)
|
||||
{
|
||||
EnvVar::Ref().Unset();
|
||||
}
|
||||
|
||||
/// Updates the cached value of an environment variable
|
||||
template <typename EnvVar, typename ValueType>
|
||||
void UpdateEnvVar(EnvVar, const ValueType& val)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
|
||||
EnvVar::Ref().UpdateValue(val);
|
||||
}
|
||||
|
||||
template <typename EnvVar>
|
||||
void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
{
|
||||
EnvVar::Ref().UpdateValue(
|
||||
ck_tile::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(
|
||||
val.data()));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#pragma clang diagnostic pop
|
||||
275
include/ck_tile/core/utility/functional.hpp
Normal file
275
include/ck_tile/core/utility/functional.hpp
Normal file
@@ -0,0 +1,275 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_for()
|
||||
{
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
|
||||
"NBegin >= NEnd)");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct applier
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
|
||||
(f(number<Is>{}), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t Size> // == sizeof...(Is)
|
||||
using make_applier = __make_integer_seq<applier, index_t, Size>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N>
|
||||
struct static_for<0, N, 1> : detail::make_applier<N>
|
||||
{
|
||||
using detail::make_applier<N>::operator();
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct static_for_product;
|
||||
template <index_t... Is>
|
||||
struct static_for_product<static_for<Is...>> : public static_for<Is...>
|
||||
{
|
||||
};
|
||||
template <index_t... Is>
|
||||
struct static_for_product<sequence<Is...>> : public static_for<Is...>
|
||||
{
|
||||
};
|
||||
template <index_t I>
|
||||
struct static_for_product<number<I>> : public static_for<0, I, 1>
|
||||
{
|
||||
};
|
||||
template <typename First, typename... Rest>
|
||||
struct static_for_product<First, Rest...>
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
static_for_product<First>{}([=](auto I) {
|
||||
static_for_product<Rest...>{}([=](auto... Is) { //
|
||||
f(I, Is...);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct identity
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
// Similar to identity, but takes an additional index parameter as the first argument.
|
||||
// The index is ignored and only the second argument (value) is forwarded.
|
||||
// Useful for indexed element-wise operations where the functor signature requires an index.
|
||||
struct idx_identity
|
||||
{
|
||||
template <typename I, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1>
|
||||
struct unpack2_impl;
|
||||
|
||||
// TODO: remove this, after properly implementing unpack that takes any number of containers
|
||||
template <index_t... Is, index_t... Js>
|
||||
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
|
||||
{
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
|
||||
std::forward<Y>(y).at(number<Js>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
|
||||
std::forward<Y>(y).template at<Js>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
// z = predicate ? x : y
|
||||
template <bool predicate, typename X, typename Y>
|
||||
constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
{
|
||||
if constexpr(predicate)
|
||||
{
|
||||
return std::forward<X>(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::forward<Y>(y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// This file should not be included inside tuple.hpp!
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
template <class F, class CurrentUnpackIds>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
|
||||
{
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_one_shot_impl
|
||||
{
|
||||
template <class F, class CurrentUnpackIds, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
|
||||
{
|
||||
constexpr auto r_lens_stride =
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
|
||||
constexpr auto r_upks_stride =
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
|
||||
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// TODO: we may unify static_ford/static_uford in the future
|
||||
//
|
||||
// loop over nd space(sequence) with packs
|
||||
// you must make sure the function passed in has same number of argument
|
||||
//
|
||||
// e.g.
|
||||
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
|
||||
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
|
||||
//
|
||||
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
|
||||
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
|
||||
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
|
||||
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
|
||||
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
|
||||
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
|
||||
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
|
||||
// ...
|
||||
template <class Lengths,
|
||||
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_uford
|
||||
{
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
using L_ = decltype(Lengths{} / Unpacks{});
|
||||
|
||||
return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id...)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
|
||||
f, make_tuple(sequence<>{}));
|
||||
}
|
||||
|
||||
// this version is friendly for issue function one by one
|
||||
template <class F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
|
||||
{
|
||||
static_assert(i_access < get_num_of_access());
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
|
||||
decltype(ordered_unpacks),
|
||||
Orders>{}(
|
||||
f, make_tuple(sequence<>{}), number<i_access>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
51
include/ck_tile/core/utility/gemm_validation.hpp
Normal file
51
include/ck_tile/core/utility/gemm_validation.hpp
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
inline void
|
||||
validate_stride(std::string Layout, int M, int N, int stride, const std::string& stride_name)
|
||||
{
|
||||
if(Layout == "C" && stride < M)
|
||||
{
|
||||
throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
|
||||
std::to_string(stride) + ") must be greater or equal to dim " +
|
||||
std::to_string(M));
|
||||
}
|
||||
if(Layout == "R" && stride < N)
|
||||
{
|
||||
throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
|
||||
std::to_string(stride) + ") must be greater or equal to dim " +
|
||||
std::to_string(N));
|
||||
}
|
||||
}
|
||||
|
||||
inline void validate_gemm_stride(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int Stride_A,
|
||||
int Stride_B,
|
||||
int Stride_C)
|
||||
{
|
||||
// set default stride
|
||||
if(Stride_A <= 0)
|
||||
Stride_A = (a_layout == "R") ? K : M;
|
||||
if(Stride_B <= 0)
|
||||
Stride_B = (b_layout == "R") ? N : K;
|
||||
if(Stride_C <= 0)
|
||||
Stride_C = (c_layout == "R") ? N : M;
|
||||
|
||||
validate_stride(a_layout, M, K, Stride_A, "Stride_A");
|
||||
validate_stride(b_layout, K, N, Stride_B, "Stride_B");
|
||||
validate_stride(c_layout, M, N, Stride_C, "Stride_C");
|
||||
}
|
||||
} // namespace ck_tile
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user