mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
introducing ck_tile! (#1216)
* enable gfx940
* switch between intrinsic mfma routines on mi100/200 and mi300
* fix mfma_int8 on MI300
* disable 2 int8 examples on MI300
* Update cmake-ck-dev.sh
* restore gitignore file
* modify Jenkinsfile to the internal repo
* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx
Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)
---
updated-dependencies:
- dependency-name: rocm-docs-core
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* initial enablement of gfx950
* fix clang format
* disable examples 31 and 41 int8 on gfx950
* add code
* fix build wip
* fix xx
* now can build
* naming
* minor fix
* wip fix
* fix macro for exp2; fix warpgemm a/b in transposedC
* unify as tuple_array
* Update the required Python version to 3.9
* Update executable name in test scripts
* re-structure tuple/array to avoid spill
* Merge function templates
* Fix format
* Add constraint to array<> ctor
* Re-use function
* Some minor changes
* remove wrong code in store_raw()
* fix compile issue in transpose
* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'
* convert more integral_constant -> constant, and formatting
* make sure thread_buffer can be tuple/array
* temp fix buffer_store spill
* not using custom data type by default, now we can have ISA-level same code as opt_padding
* fix compile error, fp8 not ready now
* fix fp8 duplicated move/shift/and/or problem
* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode
* fix scratch in fp8 kernel
* update some readme
* fix merge from upstream
* sync with upstream
* sync upstream again
* sync 22
* remove unused
* fix clang-format
* update README of ck_tile example
* fix several issues
* let python version to be 3.8 as minimal
* remove ck_tile example from default cmake target like all/install/check
* remove mistake
* 1) support recipe in generate.py 2) use simplified mask type 3) pass left/right into karg
* fix some bug in group-mode masking and codegen. update README
* F8 quantization for FMHA forward (#1224)
* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline
* Add element function to fmha api
* Adjust P elementwise function
* Fix bug of elementwise op, our elementwise op is not inout
* Add some elementwise op, prepare to quantization
* Let generate.py generate different elementwise functions
* To prevent compiler issue, remove the elementwise function we have not used.
* Remove f8 pipeline, we should share the same pipeline even in f8
* Remove remove_cvref_t
* Avoid warning
* Fix wrong fp8 QK/KV block gemm setting
* Check fp8 rounding error in check_err()
* Set fp8 rounding error for check_err()
* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode
* 1. codegen the fp8 api and kernel
2. f8 host code
* prevent warning in filter mode
* Remove not-in-use elementwise function kargs
* Remove more not-in-use elementwise function kargs
* Small refinements in C++ source files
* Use conditional_t<> to simplify code
* Support heterogeneous argument for binary function types
* Re-use already-existing scales<> functor template
* Fix wrong value produced by saturating
* Generalize the composes<> template
* Unify saturates<> implementation
* Fix type errors in composes<>
* Extend less_equal<>
* Reuse the existing template less_equal<> in check_err()
* Add equal<float> & equal<double>
* Rename check_err() parameter
* Rename check_err() parameter
* Add FIXME comment for adding new macro in future
* Remove unnecessary cast to void
* Eliminate duplicated code
* Avoid dividing api pool into more than 2 groups
* Use clearer variable names
* Use affirmative condition in if stmt
* Remove blank lines
* Do not use perfect forwarding in composes<>
* To fix compile error, revert generate.py back to 4439cc107d
* Fix bug of p element function
* Add compute element op to host softmax
* Remove element function in api interface
* Extract user parameter
* Rename pscale and oscale variable
* rename f8 to fp8
* rename more f8 to fp8
* Add pipeline::operator() without element_functor
* 1. Remove deprecated pipeline enum
2. Refine host code parameter
* Use quantization range as input
* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s
3.Add init description
* Refine description
* prevent early return
* unify _squant kernel name in cpp, update README
* Adjust the default range.
* Refine error message and bias range
* Add fp8 benchmark and smoke test
* fix fp8 swizzle_factor=4 case
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
include/ck_tile/core/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# ck_tile/core #

`ck_tile/core` contains all the basic functions and structures needed to build a GPU kernel with `ck_tile`. Users should include only the single header `ck_tile/core.hpp` to get all the functionality. Everything lives in the `ck_tile` namespace. The coding style under this folder is similar to `std` (`snake_case` for structures/functions, CamelCase for template type parameters, ...).

```
algorithm/
    coordinate transforms and other reusable algorithms
arch/
    basic device building blocks such as mma, buffer addressing, etc.
container/
    basic container data structures: array/sequence/tuple/...
numeric/
    data types and data-type-related math
tensor/
    tensor descriptors and tile-level APIs
utility/
    other utility functions for both host and device
```
include/ck_tile/core/algorithm/cluster_descriptor.hpp (new file, 38 lines)
@@ -0,0 +1,38 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
{
    constexpr index_t ndim_low = Lengths::size();

    const auto reordered_lengths = container_reorder_given_new2old(lengths, order);

    const auto low_lengths = generate_tuple(
        [&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});

    const auto transform = make_merge_transform(low_lengths);

    constexpr auto low_dim_old_top_ids = ArrangeOrder{};

    constexpr auto up_dim_new_top_ids = sequence<0>{};

    return make_single_stage_tensor_adaptor(
        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}

} // namespace ck_tile
```
include/ck_tile/core/algorithm/coordinate_transform.hpp (new file, 1672 lines)
File diff suppressed because it is too large.
include/ck_tile/core/algorithm/space_filling_curve.hpp (new file, 166 lines)
@@ -0,0 +1,166 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>
struct space_filling_curve
{
    static constexpr index_t TensorSize =
        reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
    static_assert(0 < TensorSize,
                  "space_filling_curve should be used to access a non-empty tensor");

    static constexpr index_t nDim = TensorLengths::size();

    using Index = multi_index<nDim>;

    static constexpr index_t ScalarPerVector =
        reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});

    static constexpr auto access_lengths   = TensorLengths{} / ScalarsPerAccess{};
    static constexpr auto dim_access_order = DimAccessOrder{};
    static constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, dim_access_order);

    static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(ordered_access_lengths)),
        make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
        make_tuple(sequence<0>{}));

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        static_assert(TensorLengths::size() == ScalarsPerAccess::size());
        static_assert(TensorLengths{} % ScalarsPerAccess{} ==
                      typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});

        return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
    }

    template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
    static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
                                                               number<AccessIdx1dTail>)
    {
        static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
                      "1D index out of range");
        static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
                      "1D index out of range");

        constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
        constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
        return idx_tail - idx_head;
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");

        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
    {
#if 0
        /*
         * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
         */
        constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
        constexpr auto access_strides =
            container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});

        constexpr auto idx_1d = number<AccessIdx1d>{};
        // Given tensor strides \p access_strides, and the 1D index on the space-filling
        // curve, compute the idim-th element of the multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
                auto res = idx_1d.value;
                auto id  = 0;

                static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
                    id = res / access_strides[kdim].value;
                    res -= id * access_strides[kdim].value;
                });

                return id;
            };

            constexpr auto id = compute_index_impl(idim);
            return number<id>{};
        };

        constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
#endif
        constexpr auto forward_sweep = [&]() {
            statically_indexed_array<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto idim) {
                index_t tmp = ordered_access_idx[I0];

                static_for<1, idim, 1>{}(
                    [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });

                forward_sweep_(idim) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate multi-dim tensor index
        auto idx_md = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto idim) {
                ordered_idx(idim) =
                    !SnakeCurved || forward_sweep[idim]
                        ? ordered_access_idx[idim]
                        : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   ScalarsPerAccess{};
        }();
        return idx_md;
    }

    // FIXME: rename this function
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});

        return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
    }
};

} // namespace ck_tile
```
include/ck_tile/core/arch/amd_buffer_addressing.hpp (new file, 2031 lines)
File diff suppressed because it is too large.
include/ck_tile/core/arch/arch.hpp (new file, 93 lines)
@@ -0,0 +1,93 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"

namespace ck_tile {

enum struct address_space_enum
{
    generic,
    global,
    lds,
    sgpr,
    vgpr,
};

enum struct memory_operation_enum
{
    set,
    atomic_add,
    atomic_max,
    add
};

CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
    // warpSize is defined by HIP
    return warpSize;
}

CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }

CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }

// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }

CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }

// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }

CK_TILE_DEVICE index_t get_warp_id()
{
    return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}

CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }

CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
    asm volatile("\
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
#else
    __syncthreads();
#endif
}

CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    asm volatile("\
    s_waitcnt vmcnt(0) \n \
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
}

CK_TILE_DEVICE void s_nop()
{
#if 1
    asm volatile("\
    s_nop 0 \n \
    " ::);
#else
    __builtin_amdgcn_sched_barrier(0);
#endif
}

} // namespace ck_tile
```
include/ck_tile/core/arch/utility.hpp (new file, 62 lines)
@@ -0,0 +1,62 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"

#include <stdint.h>

namespace ck_tile {

// TODO: we have "memory" clobber here because this inline asm is used for async copy
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
{
    asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
}

// NOTE: this is an immediate value
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
{
    asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}

template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_up(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    const uint32_t wrap_around_lane_delta = warpSize - lane_delta;

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
#endif
}

template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_down(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
#endif
}

} // namespace ck_tile
```
include/ck_tile/core/config.hpp (new file, 156 lines)
@@ -0,0 +1,156 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif

#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif

#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif

#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2

#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif

#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1

#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif

// With older ROCm releases we had to fall back to the tuple-array implementation,
// so turn on _USE_TUPLE if you hit a compiler error; otherwise _USE_ARRAY is the default.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif

#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif

#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// if using tuple-array as the thread_buffer implementation, we need to support {} brace init
// ... with similar behavior to array
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
#else
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
#endif
#endif

#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif

#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif

#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2

#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif

#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif

#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif

#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
#endif

#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
#endif

#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
#endif

#ifndef CK_TILE_USE_AMD_BUFFER_STORE
#define CK_TILE_USE_AMD_BUFFER_STORE 1
#endif

#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif

// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif

#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
#endif

#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif

#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif

#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif

#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif

#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif

#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
```
include/ck_tile/core/container/array.hpp (new file, 251 lines)
@@ -0,0 +1,251 @@
```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <initializer_list>

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
//      array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3, 2, 0, 0})
// use make_array_with({...}) to construct an array with behavior compatible with old ck
// TODO: manually added constructor same as old ck
template <typename T_, index_t N_>
struct array
{
    using value_type           = T_;
    static constexpr index_t N = N_;

    // TODO: do we need this?
    // using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
    // union {
    value_type data[N];
    // bulk_type __content;
    // };

    CK_TILE_HOST_DEVICE constexpr array() : data{} {}

    // TODO: will initialize the data[] with the last value repeatedly
    // behavior different from std
    CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
    {
        constexpr index_t list_size = std::initializer_list<value_type>{}.size();
        static_assert(list_size <= N, "out of bound");

        index_t i        = 0;
        value_type vlast = value_type{};

        for(const value_type& val : ilist)
        {
            data[i] = val;
            vlast   = val;
            ++i;
        }
        for(; i < N; ++i)
        {
            data[i] = vlast;
        }
    }

    template <typename Y,
              typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
                                          std::is_constructible_v<Y, value_type>>>
    CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
    {
        for(auto i = 0; i < size(); i++)
            data[i] = static_cast<value_type>(c);
    }

    // template <typename Y>
    // CK_TILE_HOST_DEVICE constexpr array(const array& o)
    // {
    //     // static_assert(ArrayType::size() == size(), "wrong! size not the same");
    //     __content = o.__content;
    // }
    // CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
    // {
    //     // static_assert(ArrayType::size() == size(), "wrong! size not the same");
    //     __content = o.__content;
    //     return *this;
    // }

    CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }

    // clang-format off
    CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
    CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
    CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }

    CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }

    CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
    CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
#if 0
    template <typename ArrayLike>
    CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
    {
        static_assert(ArrayLike::size() == size(), "wrong! size not the same");
        for(index_t i = 0; i < size(); ++i)
        {
            data[i] = arr[i];
        }
        return *this;
    }
#endif
    // type punning (strict aliasing) member functions for read/write
    // aliasing this array of type "T", "N" elements
    // as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
    static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
    constexpr int vx = sizeof(value_type) * N / sizeof(Tx)

    template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }

    // below index is for index *AFTER* type convert, not before
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
    { AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
    { AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }

    template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
    { AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
    { AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
    // clang-format on
};

// empty Array
template <typename T>
struct array<T, 0>
{
    using value_type = T;

    CK_TILE_HOST_DEVICE constexpr array() {}
    CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
    CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};

template <typename>
struct vector_traits;

// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
{
    using scalar_type                    = T;
    static constexpr index_t vector_size = N;
};

namespace details {
template <class>
struct is_ref_wrapper : std::false_type
{
```
||||
};
|
||||
template <class T>
|
||||
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
|
||||
|
||||
template <class D, class...>
|
||||
struct return_type_helper
|
||||
{
|
||||
using type = D;
|
||||
};
|
||||
template <class... Ts>
|
||||
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
|
||||
{
|
||||
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
|
||||
"Ts cannot contain reference_wrappers when D is void");
|
||||
};
|
||||
|
||||
template <class D, class... Ts>
|
||||
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
|
||||
} // namespace details
|
||||
|
||||
template <typename D = void, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
|
||||
{
|
||||
return {std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
// // make empty array
|
||||
// template <typename T>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
// {
|
||||
// return array<T, 0>{};
|
||||
// }
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
|
||||
{
|
||||
return array<T, Size>(ilist);
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
for(index_t i = 0; i < Size; ++i)
|
||||
{
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
|
||||
{
|
||||
static_assert(N <= X::size(), "");
|
||||
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
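The `get_as`/`set_as` members above view the same storage at a different element width, e.g. an `array<int8_t, 4>` as an `array<int32_t, 1>`. A minimal standalone sketch of the same size arithmetic, using `std::array` and `memcpy` copies instead of ck_tile's in-place `reinterpret_cast` (all names here are illustrative, not part of ck_tile):

```cpp
#include <array>
#include <cstdint>
#include <cstring>

// view_as: copy-based stand-in for get_as<Tx>(), which views an array of N
// elements of type T as sizeof(T)*N/sizeof(Tx) elements of type Tx
template <typename Tx, typename T, std::size_t N>
auto view_as(const std::array<T, N>& a)
{
    static_assert(sizeof(T) * N % sizeof(Tx) == 0, "sizes must divide evenly");
    std::array<Tx, sizeof(T) * N / sizeof(Tx)> r{};
    std::memcpy(r.data(), a.data(), sizeof(T) * N);
    return r;
}

inline std::uint32_t fuse_bytes()
{
    // four identical bytes alias to 0x01010101 regardless of endianness
    std::array<std::uint8_t, 4> bytes{0x01, 0x01, 0x01, 0x01};
    return view_as<std::uint32_t>(bytes)[0];
}
```

The copy keeps the sketch free of strict-aliasing concerns; the in-kernel version relies on `reinterpret_cast` for zero-cost access instead.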
include/ck_tile/core/container/container_helper.hpp (new file, 499 lines)
@@ -0,0 +1,499 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
    array<TData, NSize + 1> r;
    static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
    r[number<NSize>{}] = x;
    return r;
}

template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
    return container_concat(make_tuple(x), a);
}

template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
    return container_concat(a, make_tuple(x));
}
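`container_push_back` on an `array<T, N>` returns a new, one-element-longer array, since sizes are compile-time constants. A standalone sketch of the same semantics on `std::array` (illustrative, not the ck_tile code itself):

```cpp
#include <array>
#include <cstddef>

// push_back sketch: copy the old elements, append x, return an array of size N+1
template <typename T, std::size_t N>
constexpr std::array<T, N + 1> push_back(const std::array<T, N>& a, const T& x)
{
    std::array<T, N + 1> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = a[i]; // copy the existing elements
    r[N] = x;        // append the new one
    return r;
}
```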

// reorder array
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
    static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
    return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}

template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}

// reorder array
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& new2old)
{
    array<TData, NSize> new_array;

    for(const auto& [new_pos, old_pos] : new2old)
    {
        new_array(new_pos) = old_array[old_pos];
    }

    return new_array;
}

template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
                                const map<index_t, index_t>& old2new)
{
    array<TData, NSize> new_array;

    for(const auto& [old_pos, new_pos] : old2new)
    {
        new_array(new_pos) = old_array[old_pos];
    }

    return new_array;
}

// reorder tuple
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    return make_tuple(old_tuple[number<IRs>{}]...);
}

template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
                                                                   sequence<IRs...> old2new)
{
    return container_reorder_given_new2old(
        old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}

// reorder sequence
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
                                                                   sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}

template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
                                                                   sequence<IRs...> /* old2new */)
{
    static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");

    static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");

    constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};

    return container_reorder_given_new2old(old_seq, new2old);
}
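The new2old convention above means `result[new_pos] = input[new2old[new_pos]]`; the old2new variants simply invert the map first. A standalone sketch of the new2old gather on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// new2old reorder sketch: out[new_pos] = in[new2old[new_pos]]
template <typename T, std::size_t N>
constexpr std::array<T, N> reorder_new2old(const std::array<T, N>& in,
                                           const std::array<std::size_t, N>& new2old)
{
    std::array<T, N> out{};
    for(std::size_t new_pos = 0; new_pos < N; ++new_pos)
        out[new_pos] = in[new2old[new_pos]]; // gather from the old position
    return out;
}
```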

#if 0
// the rocm-4.1 compiler would crash on a recursive lambda
template <typename Container,
          typename Reduce,
          typename Init,
          index_t IBegin = 0,
          index_t IEnd   = Container::size(),
          index_t IStep  = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
                                                    Reduce reduce,
                                                    Init init,
                                                    number<IBegin> = number<0>{},
                                                    number<IEnd>   = number<Container::size()>{},
                                                    number<IStep>  = number<1>{})
{
    static_assert((IEnd - IBegin) % IStep == 0, "wrong!");

    // f is the recursive function, fs is a dummy copy of f
    // i is the index, r_old is the current reduction
    auto f = [&](auto fs, auto i, auto r_old) {
        auto r_new = reduce(x[i], r_old);

        if constexpr(i.value < IEnd - IStep)
        {
            // recursively call f/fs
            return fs(fs, i + number<IStep>{}, r_new);
        }
        else
        {
            return r_new;
        }
    };

    // start the recursion
    return f(f, number<IBegin>{}, init);
}
#else
// i is the index, r_old is the current reduction
template <typename Container,
          typename Reduce,
          typename ROld,
          index_t I,
          index_t IEnd,
          index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
    const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
    auto r_new = reduce(x[i], r_old);

    if constexpr(i.value < IEnd - IStep)
    {
        return container_reduce_impl(
            x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return r_new;
    }
}

// the rocm-4.1 compiler would crash on a recursive lambda
// container reduce with an initial value
template <typename Container,
          typename Reduce,
          typename Init,
          index_t IBegin = 0,
          index_t IEnd   = Container::size(),
          index_t IStep  = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
                                                    Reduce reduce,
                                                    Init init,
                                                    number<IBegin> = number<0>{},
                                                    number<IEnd>   = number<Container::size()>{},
                                                    number<IStep>  = number<1>{})
{
    static_assert((IEnd - IBegin) % IStep == 0, "wrong!");

    if constexpr(IEnd > IBegin)
    {
        return container_reduce_impl(
            x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
    }
    else
    {
        return init;
    }
}
#endif
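`container_reduce` folds `x[IBegin..IEnd)` with a binary op and an initial value; note the argument order `reduce(x[i], r_old)`. A runtime sketch of the same fold on `std::array` (illustrative, the real version unrolls at compile time via `number<>`):

```cpp
#include <array>
#include <cstddef>

// container_reduce sketch: left-to-right fold with an initial value,
// keeping ck_tile's reduce(element, accumulator) argument order
template <typename T, std::size_t N, typename Reduce>
constexpr T reduce(const std::array<T, N>& x, Reduce op, T init)
{
    T r = init;
    for(std::size_t i = 0; i < N; ++i)
        r = op(x[i], r); // same order as reduce(x[i], r_old)
    return r;
}
```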

template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
    array<TData, NSize> y;

    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r    = f(r, x[i]);
        y(i) = r;
    });

    r              = f(r, x[number<0>{}]);
    y(number<0>{}) = r;

    return y;
}

template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
    array<TData, NSize> y;

    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        y(i) = r;
        r    = f(r, x[i]);
    });

    y(number<0>{}) = r;

    return y;
#else
    array<TData, NSize> y;

    TData r = init;

    for(index_t i = NSize - 1; i > 0; --i)
    {
        y(i) = r;
        r    = f(r, x[i]);
    }

    y(0) = r;

    return y;
#endif
}

template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
    return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
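In the reverse exclusive scan, `y[i]` holds the reduction of everything to the right of position `i` (with `x[i]` itself excluded), seeded with `init`. A standalone sketch on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// reverse exclusive scan sketch: walk right-to-left, storing the running
// reduction BEFORE folding the current element in
template <typename T, std::size_t N, typename Reduce>
constexpr std::array<T, N> reverse_exclusive_scan(const std::array<T, N>& x, Reduce f, T init)
{
    std::array<T, N> y{};
    T r = init;
    for(std::size_t i = N - 1; i > 0; --i)
    {
        y[i] = r;          // exclusive: store first
        r    = f(r, x[i]); // then fold x[i] into the running value
    }
    y[0] = r;
    return y;
}
```

With `f = multiply` and `init = 1` this is exactly the lengths-to-packed-strides computation: lengths `{2, 3, 4}` scan to strides `{12, 4, 1}`.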

#if 0
// the rocm-4.1 compiler would crash on a recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);

    // f is the recursive function, fs is a dummy copy of f
    // i is the index, y_old is the current scan, r_old is the current reduction
    auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
        auto r_new = reduce(x[i], r_old);

        auto y_new = container_push_front(y_old, r_new);

        if constexpr(i.value > 1)
        {
            // recursively call f/fs
            return fs(fs, i - number<1>{}, y_new, r_new);
        }
        else
        {
            return y_new;
        }
    };

    // start the recursion
    return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is the index, y_old is the current scan, r_old is the current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
    const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
    auto r_new = reduce(x[i], r_old);

    auto y_new = container_push_front(y_old, r_new);

    if constexpr(i.value > 1)
    {
        // recurse towards the front of the tuple
        return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
    }
    else
    {
        return y_new;
    }
}

template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);

    return container_reverse_exclusive_scan_impl(
        x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif

// TODO: update this to work like container_reverse_exclusive_scan, so it can deal with tuples of
// number<>
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> y;

    TData r = init;

    static_for<NSize - 1, 0, -1>{}([&](auto i) {
        r    = f(r, x[i]);
        y(i) = r;
    });

    r              = f(r, x[number<0>{}]);
    y(number<0>{}) = r;

    return y;
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
    return container_concat(x, container_concat(ys...));
}

template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
    return unpack2(
        [&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}

template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}

template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
    return x;
}
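`container_concat` joins containers into one of combined compile-time size; the variadic overload peels one argument per recursion step. A standalone sketch of the two-array base case on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// concat sketch: join two fixed-size arrays into one of size NX + NY
template <typename T, std::size_t NX, std::size_t NY>
constexpr std::array<T, NX + NY> concat(const std::array<T, NX>& ax, const std::array<T, NY>& ay)
{
    std::array<T, NX + NY> r{};
    for(std::size_t i = 0; i < NX; ++i)
        r[i] = ax[i];       // first operand occupies the front
    for(std::size_t i = 0; i < NY; ++i)
        r[NX + i] = ay[i];  // second operand follows it
    return r;
}
```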

template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        return make_array<T>(arr[Is]...);
    }
    else
    {
        return array<T, 0>{};
    }
}

template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
    static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        return make_tuple(tup[number<Is>{}]...);
    }
    else
    {
        return tuple<>{};
    }
}

template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
    static_assert(N >= sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        for(index_t i = 0; i < picks.size(); ++i)
        {
            y(picks[i]) = x[i];
        }
    }
}

template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
    static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");

    if constexpr(sizeof...(Is) > 0)
    {
        static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
    }
}
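`get_container_subset` gathers the picked positions out of a larger container, and `set_container_subset` scatters a small container back into those positions. A standalone round-trip sketch on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// gather sketch: r[i] = a[picks[i]]
template <typename T, std::size_t N, std::size_t M>
constexpr std::array<T, M> get_subset(const std::array<T, N>& a,
                                      const std::array<std::size_t, M>& picks)
{
    std::array<T, M> r{};
    for(std::size_t i = 0; i < M; ++i)
        r[i] = a[picks[i]];
    return r;
}

// scatter sketch: y[picks[i]] = x[i]; untouched positions keep their values
template <typename T, std::size_t N, std::size_t M>
constexpr void set_subset(std::array<T, N>& y, const std::array<std::size_t, M>& picks,
                          const std::array<T, M>& x)
{
    for(std::size_t i = 0; i < M; ++i)
        y[picks[i]] = x[i];
}
```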

// return the index of the first occurrence of `value` in the sequence,
// or seq.size() if it is not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
    for(index_t i = 0; i < seq.size(); i++)
    {
        if(seq[i] == value)
            return i;
    }

    return seq.size();
}

template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
    using Seq = sequence<Is...>;

    return generate_tuple(
        [&](auto i) {
            constexpr index_t tmp = Seq::at(i);
            return number<tmp>{};
        },
        number<Seq::size()>{});
}

#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)         \
    [a_of_b_impl, a_size, bs_sizes] {                               \
        return ck_tile::generate_tuple(                             \
            [=](auto i) {                                           \
                constexpr auto b_impl    = a_of_b_impl[i];          \
                constexpr index_t b_size = bs_sizes[i];             \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                           \
            },                                                      \
            ck_tile::number<a_size>{});                             \
    }()
#else
// a constexpr index_t can't be captured ("-Wunused-lambda-capture")
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)         \
    [a_of_b_impl, bs_sizes] {                                       \
        return ck_tile::generate_tuple(                             \
            [=](auto i) {                                           \
                constexpr auto b_impl    = a_of_b_impl[i];          \
                constexpr index_t b_size = bs_sizes[i];             \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                           \
            },                                                      \
            ck_tile::number<a_size>{});                             \
    }()
#endif

} // namespace ck_tile
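`container_find` returns the index of the first match, with `seq.size()` as the not-found sentinel (the same convention as `std::find` returning `end()`). A standalone sketch of that contract on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// find sketch: index of the first occurrence, or N when absent
template <typename T, std::size_t N>
constexpr std::size_t find_first(const std::array<T, N>& a, const T& value)
{
    for(std::size_t i = 0; i < N; ++i)
        if(a[i] == value)
            return i;
    return N; // "not found" sentinel, same convention as seq.size()
}
```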
include/ck_tile/core/container/map.hpp (new file, 164 lines)
@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"

namespace ck_tile {

// naive map
template <typename key, typename data, index_t max_size = 128>
struct map
{
    using pair_type = tuple<key, data>;
    using impl_type = array<pair_type, max_size>;

    impl_type impl_;
    index_t size_;

    struct iterator
    {
        impl_type& impl_;
        index_t pos_;

        CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }

        CK_TILE_HOST_DEVICE constexpr iterator& operator++()
        {
            pos_++;
            return *this;
        }

        CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
        {
            return other.pos_ != pos_;
        }

        CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
    };

    struct const_iterator
    {
        const impl_type& impl_;
        index_t pos_;

        CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
            : impl_{impl}, pos_{pos}
        {
        }

        CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
        {
            pos_++;

            return *this;
        }

        CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
        {
            return other.pos_ != pos_;
        }

        CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
    };

    CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}

    CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }

    CK_TILE_HOST_DEVICE void clear() { size_ = 0; }

    CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
    {
        for(index_t i = 0; i < size(); i++)
        {
            if(impl_[i].template at<0>() == k)
            {
                return i;
            }
        }

        return size_;
    }

    CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
    {
        return const_iterator{impl_, find_position(k)};
    }

    CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
    {
        return iterator{impl_, find_position(k)};
    }

    CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
    {
        const auto it = find(k);

        // FIXME
        // assert(it.pos_ < size());

        return impl_[it.pos_].template at<1>();
    }

    CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
    {
        auto it = find(k);

        // if the entry is not found, append it
        if(it.pos_ == size())
        {
            impl_(it.pos_).template at<0>() = k;
            size_++;
        }

        // FIXME
        // assert(size_ <= max_size);

        return impl_(it.pos_).template at<1>();
    }

    // WARNING: needed by the compiler for C++ range-based for loops only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }

    // WARNING: needed by the compiler for C++ range-based for loops only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr const_iterator end() const
    {
        return const_iterator{impl_, size_};
    }

    // WARNING: needed by the compiler for C++ range-based for loops only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }

    // WARNING: needed by the compiler for C++ range-based for loops only, don't use this function!
    CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("map{size_: %d, ", size_);
        printf("impl_: [");
        for(const auto& [k, d] : *this)
        {
            printf("{key: ");
            print(k);
            printf(", data: ");
            print(d);
            printf("}, ");
        }
        printf("]");
        printf("}");
    }
};

} // namespace ck_tile
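The map above is a fixed-capacity linear-scan container: `operator()` appends a default entry when the key is missing, `operator[]` is read-only lookup. A standalone sketch of the same behavior on `std::array` of pairs (capacity and names are illustrative, not ck_tile's):

```cpp
#include <array>
#include <cstddef>
#include <utility>

// naive fixed-capacity map sketch: linear scan over an array of pairs;
// operator() inserts a default-constructed entry on a missing key
template <typename K, typename D, std::size_t MaxSize = 8>
struct naive_map
{
    std::array<std::pair<K, D>, MaxSize> impl_{};
    std::size_t size_ = 0;

    std::size_t find_position(const K& k) const
    {
        for(std::size_t i = 0; i < size_; ++i)
            if(impl_[i].first == k)
                return i;
        return size_; // not found
    }

    D& operator()(const K& k)
    {
        std::size_t pos = find_position(k);
        if(pos == size_) // missing key: append a new entry
        {
            impl_[pos].first = k;
            ++size_;
        }
        return impl_[pos].second;
    }
};
```

Lookup and insertion are both O(size), which is fine for the handful of entries this is meant to hold.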
include/ck_tile/core/container/meta_data_buffer.hpp (new file, 99 lines)
@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>

namespace ck_tile {

// TODO: this structure is not intended to be used by the user
template <index_t MaxSize>
struct meta_data_buffer
{
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}

    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
        : buffer_{}, size_{0}
    {
        push(x, xs...);
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr void push(const T& data)
    {
        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            auto tmp = bit_cast<array<std::byte, size>>(data);

            for(int i = 0; i < size; i++)
            {
                buffer_(size_) = tmp[i];

                size_++;
            }
        }
    }

    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
    {
        push(x);
        push(xs...);
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
    {
        T data;

        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            array<std::byte, size> tmp;

            for(int i = 0; i < size; i++)
            {
                tmp(i) = buffer_[pos];

                pos++;
            }

            data = bit_cast<T>(tmp);
        }

        return data;
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
    {
        constexpr index_t size = sizeof(T);

        array<std::byte, size> tmp;

        for(int i = 0; i < size; i++)
        {
            tmp(i) = buffer_[pos];

            pos++;
        }

        auto data = bit_cast<T>(tmp);

        return data;
    }

    array<std::byte, MaxSize> buffer_;
    index_t size_ = 0;
};

} // namespace ck_tile
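`meta_data_buffer` serializes trivially-copyable values into a flat byte buffer with `push()` and reads them back in the same order with `pop()`, which advances a caller-held read cursor. A standalone sketch using `memcpy` in place of `bit_cast` (illustrative, not the ck_tile code):

```cpp
#include <array>
#include <cstddef>
#include <cstring>

// byte_buffer sketch: push appends raw bytes at the write offset;
// pop reads them back and advances the caller's read cursor
template <std::size_t MaxSize>
struct byte_buffer
{
    std::array<unsigned char, MaxSize> buffer_{};
    std::size_t size_ = 0;

    template <typename T>
    void push(const T& data)
    {
        std::memcpy(buffer_.data() + size_, &data, sizeof(T));
        size_ += sizeof(T);
    }

    template <typename T>
    T pop(std::size_t& pos) const
    {
        T data;
        std::memcpy(&data, buffer_.data() + pos, sizeof(T));
        pos += sizeof(T); // advance the read cursor past this value
        return data;
    }
};
```

Values must be popped with the same types and in the same order as they were pushed; the buffer stores no type information.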
include/ck_tile/core/container/multi_index.hpp (new file, 100 lines)
@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

// Don't use this directly. This is for old CK's internal usage;
// in the future, always use array instead
template <index_t N>
using multi_index = array<index_t, N>;

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
    return make_array<index_t>(index_t{xs}...);
}

template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
    return unpack([](auto... xs) { return make_multi_index(xs...); },
                  typename uniform_sequence_gen<NSize, 0>::type{});
}

template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
    return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}

template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}

template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");
    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
    return r;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");
    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
    return r;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");
    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
    return r;
}

// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
    multi_index<NSize> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}

// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
    return a * x;
}

} // namespace ck_tile
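`multi_index` is just `array<index_t, N>`, and the operators above apply arithmetic lane by lane, with scalar multiplication scaling every lane. A standalone sketch of those semantics on `std::array` (illustrative only):

```cpp
#include <array>
#include <cstddef>

// elementwise add: r[i] = a[i] + b[i], like multi_index operator+
template <std::size_t N>
constexpr std::array<int, N> add(const std::array<int, N>& a, const std::array<int, N>& b)
{
    std::array<int, N> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = a[i] + b[i];
    return r;
}

// scalar * multi_index: scales every lane
template <std::size_t N>
constexpr std::array<int, N> scale(int s, const std::array<int, N>& x)
{
    std::array<int, N> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = s * x[i];
    return r;
}
```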
include/ck_tile/core/container/sequence.hpp (new file, 1114 lines): file diff suppressed because it is too large
include/ck_tile/core/container/span.hpp (new file, 78 lines)
@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>

namespace ck_tile {

// implements the C++20 std::span: a lightweight, non-owning reference to a sequence,
// whether it has a dynamic or a static range; it can also be seen as a view of a contiguous sequence
// TODO: do we need this on the device side, considering it holds a pointer?
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}

    CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
    {
    }

    CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
        : span(arr.data(), N)
    {
    }

    template <typename Container>
    CK_TILE_HOST_DEVICE constexpr span(const Container& container)
        : span(container.data(), container.size())
    {
    }

    CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }

    CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
    CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }

    CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
    CK_TILE_HOST_DEVICE constexpr reference back() const { return *(end() - 1); }

    CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
    {
        return *(begin() + idx);
    }
    CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }

    CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck_tile
include/ck_tile/core/container/statically_indexed_array.hpp (new file, 41 lines)
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"

namespace ck_tile {

#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE

template <typename T, index_t N>
using statically_indexed_array = tuple_array<T, N>;

#else

// consider marking this alias as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;

#endif

// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}

// make an empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
    return statically_indexed_array<X, 0>();
}
#endif
} // namespace ck_tile
include/ck_tile/core/container/thread_buffer.hpp (new file, 165 lines)
@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"

namespace ck_tile {

#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_tuple(ts...);
}
#else

#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_array(ts...);
}

#endif

// clang-format off
template<typename T_, index_t N_>
struct thread_buffer {
    using value_type = remove_cvref_t<T_>;
    static constexpr index_t N = N_;

    value_type data[N];

    // TODO: this ctor cannot be omitted
    CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
    CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}

    CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
    CK_TILE_HOST_DEVICE auto & get() {return data; }
    CK_TILE_HOST_DEVICE const auto & get() const {return data; }
    CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
    CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
    CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: for compatibility
    CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }

    template <typename X_,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr auto _get_as() const
    {
        using X = remove_cvref_t<X_>;

        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        static_assert(N % kSPerX == 0);

        union {
            thread_buffer<X_, N / kSPerX> data {};
            // tuple_array<value_type, kSPerX> sub_data;
            value_type sub_data[N];
        } vx;
        static_for<0, N, 1>{}(
            [&](auto j) { vx.sub_data[j] = data[j]; });
        return vx.data;
    }

    template <typename X_,
              index_t Is,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
    {
        using X = remove_cvref_t<X_>;

        constexpr index_t kSPerX = vector_traits<X>::vector_size;

        union {
            X_ data {};
            tuple_array<value_type, kSPerX> sub_data;
        } vx;
        static_for<0, kSPerX, 1>{}(
            [&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
        return vx.data;
    }

#if 0
    template <typename X_,
              index_t Is,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
    {
        using X = remove_cvref_t<X_>;

        constexpr index_t kSPerX = vector_traits<X>::vector_size;

        union {
            X_ data;
            tuple_array<value_type, kSPerX> sub_data;
        } vx {x};

        static_for<0, kSPerX, 1>{}(
            [&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
    }
#endif


#define TB_COMMON_AS() \
    static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
    constexpr int vx = sizeof(value_type) * N / sizeof(Tx)

    template<typename Tx>
    CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
    template<typename Tx>
    CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1 )
            return _get_as<Tx>(); // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
    template<typename Tx, index_t I>
    CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
    template<typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1 )
            return _get_as<Tx>(number<I>{}); // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}

    template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
    { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
    { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }

#undef TB_COMMON_AS
};
// clang-format on

template <typename>
struct vector_traits;

// specialization for thread_buffer
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
    using scalar_type = T;
    static constexpr index_t vector_size = N;
};

#endif

} // namespace ck_tile
include/ck_tile/core/container/tuple.hpp (new file, 781 lines)
@@ -0,0 +1,781 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>

#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif

namespace ck_tile {

namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}

template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;

namespace impl {

// the place where the content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_object
{
};

template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
#endif
};

template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
#endif
    T element;
};

// NOTE: we return an instance (not a reference) when the content is empty
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
    return {};
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
    return static_cast<T&&>(x.element);
}

template <typename index_seq, typename... T>
struct tuple_base;

template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
    CK_TILE_HOST_DEVICE constexpr tuple_base() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (us.size() - 1)
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
        : tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
    {
    }
#undef _ILE
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
        : tuple_object<I, T>(std::forward<U>(u))...
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <class U,
              typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
                                          !std::is_same<remove_cvref_t<U>, tuple_base>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
    {
    }

    template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
    {
        static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
                      "wrong! inconsistent size");
    }

#endif
};
} // namespace impl

template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
    CK_TILE_HOST_DEVICE
    static constexpr auto size() { return sizeof...(T); }
    using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
    CK_TILE_HOST_DEVICE constexpr tuple() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
    {
    }
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
        : base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }

    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <
        typename U,
        typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
                                bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
    {
    }

    template <typename... U,
              typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }
#endif
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool flag = true;

        static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
            flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
        });

        return flag;
    }

#define TP_COM_() static_assert(I < size(), "wrong! out of range")
    // clang-format off
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }

    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }

    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
    template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: for compatibility

    // the functions below should only be used on tuple_array<> types; no extra checks are performed here
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
    // the index below is the index *AFTER* the type conversion, not before
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }

    // template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }

    // clang-format on
#undef TP_COM_
};

template <typename>
struct vector_traits;

// specialization for tuple
template <typename... T>
struct vector_traits<tuple<T...>>
{
    using scalar_type = __type_pack_element<0, T...>;
    static constexpr index_t vector_size = sizeof...(T);
};

// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
//     return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    bool same = true;

    static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
        if(a[i] != b[i])
        {
            same = false;
        }
    });

    return same;
}

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    return !(a == b);
}

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
    // here each xs is always an lvalue, since it is a function argument
    // Xs may be deduced as follows (e.g. try passing in an integer in each case):
    // 1) if passed an rvalue (like a function return, or int{}) -> Xs is "int"
    // 2) if passed a const lvalue                               -> Xs is "const int&"
    // 3) if passed a non-const lvalue                           -> Xs is "int&"
    // so the return type of std::forward depends on Xs:
    // 1) std::forward -> int&&
    // 2) std::forward -> const int&
    // 3) std::forward -> int&
    return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}

// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}

template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};

namespace impl {
// be very careful when using this type (because we want the internal type):
// template deduction will fail when inferring the inner type
// e.g.
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> the compiler will fail to deduce this type, because it is a non-deduced context
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
// contexts")
//
// -> use this instead
// template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
    using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
                                       typename tuple_array_impl<T, N - N / 2>::type>::type;
};

template <typename T>
struct tuple_array_impl<T, 0>
{
    using type = tuple<>;
};

template <typename T>
struct tuple_array_impl<T, 1>
{
    using type = tuple<T>;
};
} // namespace impl

template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return tie(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
                                                             const tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
    return tx;
}

template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}

namespace detail {

template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}

} // namespace detail

template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
    return detail::transform_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    return detail::transform_tuples_impl(
        f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    return detail::transform_tuples_impl(
        f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

// By default unroll until the tuple is fully flattened
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
    return t;
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
    return make_tuple(t);
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
{
    if constexpr(Depth == MaxDepth)
    {
        return t;
    }
    else
    {
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
            },
            t);
    }
}

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
    return generate_tuple(
        [&](auto i) {
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()>{});
}

// Reduce the tuple values in a specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
{
    static_assert(Idx < End, "Wrong parameters for tuple_reduce");
    if constexpr(Idx + 1 == End)
    {
        return t.at(number<Idx>{});
    }
    else
    {
        return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
    }
}

template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
    return (is_detected<is_tuple, Ts>::value || ...);
}

template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
    return depth;
}

template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
    return max(tuple_depth<depth + 1>(Ts{})...);
}

template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
    constexpr index_t n0 = sizeof...(Seqs);

    constexpr index_t max_n1 = [&] {
        index_t max_n1_ = 0;

        static_for<0, n0, 1>{}([&](auto i0) {
            constexpr index_t n1 = t_of_s[i0].size();

            max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
        });

        return max_n1_;
    }();

    array<array<index_t, max_n1>, n0> a_of_a{{-1}};

    static_for<0, n0, 1>{}([&](auto i0) {
        constexpr index_t n1 = t_of_s[i0].size();

        static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
    });

    return a_of_a;
}
|
||||
|
||||
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
|
||||
// is the alias of the latter. This is because compiler cannot infer the NSize if
|
||||
// using MultiIndex<NSize>
|
||||
// TODO: how to fix this?
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = scalar * MultiIndex
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = MultiIndex * scalar
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
|
||||
}
|
} // namespace ck_tile

#include <tuple>
// WARNING: needed by the compiler for C++ structured-binding support only; do not use directly
namespace std {

template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
{
};

template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
    : std::tuple_element<I, const std::tuple<Ts...>>
{
};

} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n)                                                             \
    _Pragma("clang diagnostic push") _Pragma(                                                \
        "clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
        ck_tile::sequence<IDX_IDX_...>)                                                      \
    {                                                                                        \
        return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{};         \
    }                                                                                        \
    (ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_)                                                      \
    [&arr, n_] {                                                                         \
        static_assert(arr.size() >= n_, "wrong! out of bound");                          \
                                                                                         \
        static_assert(n_ < 7, "not implemented");                                        \
                                                                                         \
        if constexpr(n_ == 0)                                                            \
        {                                                                                \
            return ck_tile::tuple<>{};                                                   \
        }                                                                                \
        else if constexpr(n_ == 1)                                                       \
        {                                                                                \
            return ck_tile::tuple<number<arr[0]>>{};                                     \
        }                                                                                \
        else if constexpr(n_ == 2)                                                       \
        {                                                                                \
            return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{};                     \
        }                                                                                \
        else if constexpr(n_ == 3)                                                       \
        {                                                                                \
            return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{};     \
        }                                                                                \
        else if constexpr(n_ == 4)                                                       \
        {                                                                                \
            return ck_tile::                                                             \
                tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
        }                                                                                \
        else if constexpr(n_ == 5)                                                       \
        {                                                                                \
            return ck_tile::tuple<number<arr[0]>,                                        \
                                  number<arr[1]>,                                        \
                                  number<arr[2]>,                                        \
                                  number<arr[3]>,                                        \
                                  number<arr[4]>>{};                                     \
        }                                                                                \
        else if constexpr(n_ == 6)                                                       \
        {                                                                                \
            return ck_tile::tuple<number<arr[0]>,                                        \
                                  number<arr[1]>,                                        \
                                  number<arr[2]>,                                        \
                                  number<arr[3]>,                                        \
                                  number<arr[4]>,                                        \
                                  number<arr[5]>>{};                                     \
        }                                                                                \
    }()
#endif
include/ck_tile/core/numeric/bfloat16.hpp (new file, 342 lines)
@@ -0,0 +1,342 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>

#pragma once

namespace ck_tile {

enum class bf16_rounding_mode
{
    standard = 0, // round to nearest (rtn)
    truncate_with_nan,
    truncate,
};

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x);

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses __hip_bfloat16 as a struct
struct alignas(2) bfloat16_t
{
    using raw_type = uint16_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr bfloat16_t bit_cast(raw_type x)
    {
        bfloat16_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr bfloat16_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const unsigned int& x)
        : data(float_to_bf16_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf16_to_float_raw(data); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return bf16_to_double_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;

template <>
struct native_t<bfloat16_t>
{
    using type = ushort;
};
using bf16_t     = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t     = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
// round to nearest
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    if(~u.int32 & 0x7f800000)
    {
        // When the exponent bits are not all 1s, the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16 mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This rounds to even when the lower 16
        // bits are exactly 0x8000. If the bfloat16 mantissa already has the
        // value 0x7f, then incrementing it causes it to become 0x00 and the
        // exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
    }
    else if(u.int32 & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        u.int32 |= 0x10000; // Preserve signaling NaN
    }
    return uint16_t(u.int32 >> 16);
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}

// Fast truncation instead of rounding (RTZ)
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16);
}

template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
    if constexpr(rounding == bf16_rounding_mode::standard)
        return float_to_bf16_rtn_raw(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else
        return float_to_bf16_truc_raw(f);
}

template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
    return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};
    return u.fp32;
}

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x)
{
    return static_cast<double>(bf16_to_float_raw(x));
}

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float(x)); }

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }

template <class T>
struct numeric;

template <>
struct numeric<bfloat16_t>
{
    // minimum positive normalized value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
    }

    // difference between 1.0 and the next value representable in bf16
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
    }

    // maximum rounding error
    // bin : f edcba 9876543210
    // bits: s eeeeeeee mmmmmmm
    //       0 01111110 0000000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
    }
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
    }
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif

// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
    return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}

CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    uint16_t xx = bit_cast<bf16_raw_t>(x);
    return (xx & 0x7FFF) > 0x7F80; // NaN iff exponent is all ones and mantissa is nonzero
}

CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
    return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
}

CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); }

CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); }

CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); }

} // namespace ck_tile

include/ck_tile/core/numeric/float8.hpp (new file, 871 lines)
@@ -0,0 +1,871 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <type_traits>

#pragma once

namespace ck_tile {

// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, which helps to avoid error accumulation
enum class fp8_rounding_mode
{
    standard = 0,
    stochastic
};

/*
 * ______________NANOO_________________ | ______________IEEE________________
 *             e4m3              e5m2                |        e4m3             e5m2
 * bias      : 8                 16                  |        7                15
 * inf       : 1.0000.000        1.00000.00          |        N/A              s.11111.00
 * NaN       : 1.0000.000        1.00000.00          |        s.1111.111       s.11111.{01, 10, 11}
 * zero      : 0.0000.000        0.00000.00          |        s.0000.000       s.00000.00
 * Max(norm) : s.1111.111 (240)  s.11111.11 (57344)  |        s.1111.110 (448) s.11110.11 (57344)
 * Max(snorm): s.0000.111        s.00000.11          |        s.0000.111       s.00000.11
 *             0.0068359375      2.288818e-05        |        0.013671875      4.57763671875e-05
 * Min(norm) : s.0001.000        s.00001.00          |        s.0001.000       s.00001.00
 *             2^-7 (0.00078125) 2^-15 (3.05176e-05) |        2^-6 (0.015625)  2^-14 (6.10352e-05)
 * Min(snorm): s.0000.001        s.00000.01          |        s.0000.001       s.00000.01
 *             2^-10 (0.00097656) 2^-17 (7.629395e-06) |      2^-9 (0.001953125) 2^-16 (1.52588e-05)
 */
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});

CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e4m3_t bit_cast(raw_type x)
    {
        float8_e4m3_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e4m3_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
    {
    }

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const unsigned int& x)
        : data(float_to_fp8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};
using fp8_t     = float8_e4m3_t;
using fp8_raw_t = typename fp8_t::raw_type;

struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e5m2_t bit_cast(raw_type x)
    {
        float8_e5m2_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e5m2_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
    {
    }

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const unsigned int& x)
        : data(float_to_bf8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};
using bf8_t     = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;

template <typename>
struct native_t;

template <>
struct native_t<fp8_t>
{
    using type = _BitInt(8);
};

template <>
struct native_t<bf8_t>
{
    using type = unsigned _BitInt(8);
};

#else
using fp8_t     = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t     = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
namespace impl {
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_traits<Y>::exp;
|
||||
constexpr int out_mant = numeric_traits<Y>::mant;
|
||||
|
||||
// original type exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_traits<X>::exp;
|
||||
constexpr int in_mant = numeric_traits<X>::mant;
|
||||
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
constexpr Y nan_code =
|
||||
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
#else
|
||||
constexpr Y nan_code = 0x80;
|
||||
#endif
|
||||
|
||||
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
using T_bitwise = typename numeric_traits<X>::bitwise_type;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
head = x_bitwise & numeric_traits<X>::head_mask;
|
||||
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
bias = numeric_traits<X>::bias;
|
||||
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return nan_code;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((x_bitwise & nan_mask) == nan_mask)
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
|
||||
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
|
||||
// exponent and mantissa again3
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
|
||||
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// out_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
if(exponent == 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
|
||||
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
|
||||
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
|
||||
   fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16
   numbers where exponent==0 (actual exponent -14) and the highest bit of mantissa is 1 are
   bf8 (NANOO) normal. In this case, the fp16 mantissa should be shifted left by 1 */
            act_exponent  = exponent - bias + 1;
            exponent_diff = out_denormal_act_exponent -
                            act_exponent; // actual exponent is exponent-bias+1 as it is denormal
        }
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= out_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in the f8 denormal range.
               For example, in fp8 nanoo mode the denormal exponent is -7, but if the fp32/fp16
               actual exponent is -7, it is actually larger due to the implicit 1.
               Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
               So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = out_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff =
                0; // exponent_diff=0 does not mean there is no difference for this case;
                   // act_exponent could be larger. It just does not need a mantissa shift
        }
        mantissa += (1 << in_mant); // add the implicit 1 into the mantissa
    }

    bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
                    (1 << (in_mant - out_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
       shift right, as shifting right could strip off some residual part and make something that
       is not a midpoint look like a midpoint. For example, the fp16 number 0x1002
       (0 00100 0000000010) is larger than the midpoint, but after shifting right by 4 bits it
       would look like one. */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << in_mant);
    // if there is no implicit 1, the f8 is denormal and we need to adjust to the denorm exponent
    out_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    bool odd =
        mantissa &
        (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(out_exponent == 0)
    {
        if((1 << in_mant) & mantissa)
        {
            out_exponent = 1; // denormal overflows to become normal, promote exponent
            // No need to make the 1 implicit now as it will be addressed later
        }
    }
    else
    {
        if((1 << (in_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
            // No need to make the 1 implicit now as it will be addressed later
        }
    }

    mantissa >>= (in_mant - out_mant);

    if(out_exponent > max_exp)
    {
        if(clip)
        {
            mantissa     = (1 << out_mant) - 1;
            out_exponent = max_exp;
        }
        else
        {
            return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
        }
    }

    // check if x is 0.0 or -0.0
    if(out_exponent == 0 && mantissa == 0)
        return __builtin_bit_cast(
            Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
    mantissa &= (1 << out_mant) - 1;
    return __builtin_bit_cast(Y,
                              static_cast<uint8_t>((sign << (out_exp + out_mant)) |
                                                   (out_exponent << out_mant) | mantissa));
}

template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int in_exp  = numeric_traits<X>::exp;
    constexpr int in_mant = numeric_traits<X>::mant;

    // resulting type exponent/mantissa layout
    constexpr int out_exp  = numeric_traits<Y>::exp;
    constexpr int out_mant = numeric_traits<Y>::mant;
    uint8_t x_raw          = __builtin_bit_cast(uint8_t, x);

    // prepare the codes
    constexpr uint8_t nan_code = 0x80;
    Y Inf, NegInf, NaN, Neg0;
    using T_bitwise = typename numeric_traits<Y>::bitwise_type;

    constexpr T_bitwise Inf_bitwise    = numeric_traits<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise    = numeric_traits<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise   = numeric_traits<Y>::Neg0;

    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));

    // check if x is 0.0
    if(x_raw == 0)
        return static_cast<Y>(0);

    // unpack the input
    uint32_t sign     = x_raw >> (in_exp + in_mant);
    uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
    int exponent      = (x_raw & 0x7F) >> in_mant;

    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    T_bitwise retval;

    if constexpr(negative_zero_nan)
    {
        if(x_raw == nan_code)
            return NaN;
    }
    else
    {
        if(x_raw == nan_code)
            return Neg0;
        if(exponent == ((1 << in_exp) - 1))
            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

    if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
    {
        retval = x_raw;
        retval <<= 8;
        return *(reinterpret_cast<const Y*>(&retval));
    }

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return *(reinterpret_cast<const Y*>(&retval));
}

template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    // check datatypes
    constexpr bool is_half  = std::is_same<X, half_t>::value;
    constexpr bool is_float = std::is_same<X, float>::value;
    static_assert(is_half || is_float, "Only half and float can be cast.");

    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}

template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
    // check datatype
    constexpr bool is_half  = std::is_same<Y, half_t>::value;
    constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "Only half and float are supported.");

    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace impl

CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
                                                fp8_t,
                                                negative_zero_nan,
                                                clip,
                                                (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
                                                bf8_t,
                                                negative_zero_nan,
                                                clip,
                                                (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
                                                fp8_t,
                                                negative_zero_nan,
                                                clip,
                                                (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
                                                bf8_t,
                                                negative_zero_nan,
                                                clip,
                                                (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

// clang-format off
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
    if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
    else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
    else return fp8_raw_t{0};
}

template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
    if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
    else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
    else return bf8_raw_t{0};
}

CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif
}

CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif
}

template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
    return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}

template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
    return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
{
    return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}

CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
{
    return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}

// clang-format on

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<fp8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 8;
#else
    static constexpr int bias = 7;
#endif
};

template <>
struct numeric_traits<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 16;
#else
    static constexpr int bias = 15; // IEEE
#endif
};

template <class T>
struct numeric;

template <>
struct numeric<fp8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr fp8_t min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
    }

    // difference between 1.0 and the next representable value
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
    }

    // maximum rounding error
    // bin : 7 6543 210
    // bits: s eeee mmm
    //       0 0110 000 (0.5)
    //
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
    }
};

template <>
struct numeric<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
    }

    // difference between 1.0 and the next representable value
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
    }

    // maximum rounding error
    // bin : 7 65432 10
    // bits: s eeeee mm
    //       0 01110 00 (0.5)
    //
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif

// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x)
{
    return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };

CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
    return bit_cast<bf8_t>(static_cast<bf8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };

} // namespace ck_tile

385	include/ck_tile/core/numeric/half.hpp	Normal file
@@ -0,0 +1,385 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>

#pragma once

namespace ck_tile {

using fp16_hip_t = _Float16; // most HIP internal functions use this type
using fp16_raw_t = uint16_t;

CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x);

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x);

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x);

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses fp16_hip_t as the interchangeable data type for float16
struct alignas(2) half_t
{
    using raw_type = fp16_raw_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr half_t bit_cast(raw_type x)
    {
        half_t y;
        y.data = x;
        return y;
    }

    CK_TILE_HOST_DEVICE
    constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // constructor
    constexpr half_t() : data{} {}

    // construct from HIP half
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const unsigned int& x)
        : half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const
    {
        return static_cast<int>(fp16_to_float_hip(to_fp16()));
    }

    CK_TILE_HOST_DEVICE
    explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

template <typename>
struct native_t;

template <>
struct native_t<half_t>
{
    using type = _Float16;
};

using fp16_t     = half_t;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t     = _Float16;
using half_t     = _Float16;
using fp16_raw_t = ushort;
#endif

// conversions
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
    // return __half2float(x);
    return static_cast<float>(x);
}

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
    return static_cast<double>(fp16_to_float_hip(x));
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
    return __float2half(x);
    // return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
    // return __float2half(x);
    return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double(const half_t& x) { return static_cast<double>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }

// limits
template <class T>
struct numeric;

template <>
struct numeric<half_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr half_t min()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t lowest()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t max()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
    }

    // difference between 1.0 and the next representable value
    CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
    }

    // maximum rounding error
    // bin : f edcba 9876543210
    // bits: s eeeee mmmmmmmmmm
    //       0 01110 0000000000 (0.5)
    //
    CK_TILE_HOST_DEVICE static constexpr half_t round_error()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr half_t infinity()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
    }

    CK_TILE_HOST_DEVICE static constexpr half_t zero()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<half_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;
    using bitwise_type = uint16_t;
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
{
    return __heq(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }

#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
    return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}

CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }

CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
    return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}

CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
    return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}

CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
    return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}

CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
    x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
    x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
    x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
    x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
    x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
    x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
    return x;
}

CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
    half_t y(x);
    x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
    return y;
}

CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
    half_t y(x);
    x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
    return y;
}
#endif

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif

// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return bit_cast<half_t>(static_cast<fp16_raw_t>(x.get() & 0x7fff)); }

CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7C00;
}

CK_TILE_DEVICE
half_t sqrt(half_t x)
{
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};

CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
} // namespace ck_tile

13	include/ck_tile/core/numeric/integer.hpp	Normal file
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <stdint.h>

namespace ck_tile {

using index_t      = int32_t;
using long_index_t = int64_t;
using int8_t       = int8_t;

} // namespace ck_tile

83	include/ck_tile/core/numeric/integral_constant.hpp	Normal file
@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"

namespace ck_tile {

template <auto v>
struct constant
{
    using value_type = decltype(v);
    using type       = constant; // using injected-class-name
    static constexpr value_type value = v;
    CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
    CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};

template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type       = integral_constant; // using injected-class-name
    static constexpr T value = v;
    // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
};

template <index_t v>
using number = constant<v>;

template <long_index_t v>
using long_number = constant<v>;

template <bool b>
using bool_constant = constant<b>;

#define CK_TILE_LEFT_UNARY_OP(OP)                               \
    template <auto x>                                           \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
    {                                                           \
        return constant<(OP x)>{};                              \
    }

#define CK_TILE_BINARY_OP(OP)                                                \
    template <auto x, auto y>                                                \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
    {                                                                        \
        return constant<(x OP y)>{};                                         \
    }

CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)

CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)

#undef CK_TILE_LEFT_UNARY_OP
|
||||
#undef CK_TILE_BINARY_OP
|
||||
|
||||
} // namespace ck_tile
|
||||
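As a standalone sketch (not part of the diff itself), the following shows what `CK_TILE_BINARY_OP(+)` expands to and how arithmetic on `constant<>` folds values at compile time; the host/device macros are omitted here for brevity:

```cpp
// Minimal standalone model of ck_tile::constant<> plus the operator that
// CK_TILE_BINARY_OP(+) generates. The sum of two constants is itself a
// constant carrying the folded value as part of its type.
template <auto v>
struct constant
{
    using value_type = decltype(v);
    static constexpr value_type value = v;
    constexpr operator value_type() const noexcept { return value; }
};

// expansion of CK_TILE_BINARY_OP(+), without the host/device attribute
template <auto x, auto y>
constexpr auto operator+(constant<x>, constant<y>)
{
    return constant<(x + y)>{};
}

// the result type itself encodes the folded value
static_assert(decltype(constant<2>{} + constant<3>{})::value == 5);
```

The implicit conversion operator also lets a `constant<>` decay to its runtime value, e.g. `int n = constant<7>{};`.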
539 include/ck_tile/core/numeric/math.hpp Normal file
@@ -0,0 +1,539 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>

namespace ck_tile {

template <typename Scale, Scale lhs>
struct scales_c
{
    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

template <typename Scale>
struct scales
{
    static_assert(std::is_copy_constructible_v<Scale>);

    CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}

    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
        -> decltype(std::declval<const Scale&>() * rhs)
    {
        return lhs_ * rhs;
    }

    private:
    Scale lhs_;
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale) -> scales<Scale>;

template <typename Left = void, typename Right = Left>
struct plus
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

template <>
struct plus<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus() -> plus<void, void>;

template <typename Left = void, typename Right = Left>
struct minus
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

template <>
struct minus<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus() -> minus<void, void>;

template <typename Left = void, typename Right = Left>
struct multiplies
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

template <>
struct multiplies<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies() -> multiplies<void, void>;

template <typename T>
struct maximize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

template <typename T>
struct minimize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};

template <typename T>
struct integer_divide_ceiler
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
        return (a + b - number<1>{}) / b;
    }
};

template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}

template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - number<1>{}) / y;
}

template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}

template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
    return x;
}

template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

template <>
CK_TILE_DEVICE constexpr float max(float x, float y)
{
    return __builtin_fmaxf(x, y); // can result in v_max3_f32
}

template <>
CK_TILE_DEVICE constexpr double max(double x, double y)
{
    return __builtin_fmax(x, y); // maybe still v_max3_f32
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
    return X > y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
    return x > Y ? x : Y;
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough arguments");
    return max(x, max(ys...));
}

template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
    return x;
}

template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

template <>
CK_TILE_DEVICE constexpr float min(float x, float y)
{
    return __builtin_fminf(x, y);
}

template <>
CK_TILE_DEVICE constexpr double min(double x, double y)
{
    return __builtin_fmin(x, y);
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
    return X < y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
    return x < Y ? x : Y;
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough arguments");
    return min(x, min(ys...));
}

template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    return min(max(x, lowerbound), upperbound);
}

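The integer helpers above are the workhorses of tile-size arithmetic. A standalone sketch (plain `int` standing in for `ck_tile::index_t` and `number<1>`, an illustration rather than the header itself):

```cpp
// Ceil-division, rounding up to a multiple, and clamping, mirroring the
// logic of integer_divide_ceil / integer_least_multiple / clamp above.
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }
constexpr int integer_least_multiple(int x, int y) { return y * integer_divide_ceil(x, y); }
constexpr int clamp(int x, int lo, int hi) { return x < lo ? lo : (x > hi ? hi : x); }

static_assert(integer_divide_ceil(10, 4) == 3);     // 10 elements in tiles of 4 -> 3 tiles
static_assert(integer_least_multiple(10, 4) == 12); // pad 10 up to the next multiple of 4
static_assert(clamp(17, 0, 15) == 15);
```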
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }

// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
    if(x < 0)
    {
        return gcd(-x, y);
    }
    else if(y < 0)
    {
        return gcd(x, -y);
    }
    else if(x == y || x == 0)
    {
        return y;
    }
    else if(y == 0)
    {
        return x;
    }
    else if(x > y)
    {
        return gcd(x % y, y);
    }
    else
    {
        return gcd(x, y % x);
    }
}

template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
    constexpr auto r = gcd(X, Y);

    return number<r>{};
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
{
    return gcd(x, gcd(ys...));
}

// least common multiple
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
    return (x * y) / gcd(x, y);
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
{
    return lcm(x, lcm(ys...));
}
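The `gcd` above is the Euclidean algorithm by repeated remainders, with `lcm` built on top of it. A standalone sketch with plain `int` in place of `ck_tile::index_t`:

```cpp
// Recursive Euclidean gcd matching the branch structure above, plus lcm.
constexpr int gcd(int x, int y)
{
    if(x < 0) return gcd(-x, y);
    if(y < 0) return gcd(x, -y);
    if(x == y || x == 0) return y;
    if(y == 0) return x;
    return x > y ? gcd(x % y, y) : gcd(x, y % x);
}

constexpr int lcm(int x, int y) { return (x * y) / gcd(x, y); }

static_assert(gcd(12, 18) == 6);
static_assert(lcm(4, 6) == 12);
```

The `number<X>` overload in the header wraps the same computation so the result stays a compile-time constant.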

template <typename Left = void, typename Right = Left>
struct equal
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

template <>
struct equal<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal() -> equal<void, void>;

template <>
struct equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        return bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
    }
};

template <>
struct equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        return bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
    }
};

template <typename Left = void, typename Right = Left>
struct less
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

template <>
struct less<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less() -> less<void, void>;

template <typename Left = void, typename Right = Left>
struct less_equal
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

template <>
struct less_equal<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal() -> less_equal<void, void>;

template <>
struct less_equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        return lhs < rhs || bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
    }
};

template <>
struct less_equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
    }
};

CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
    // TODO: x needs to be within 2 ~ 0x7fffffff; 0, 1, or values larger than
    // 0x7fffffff will fail to compile
    return 1 << (32 - clz(x - 1));
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}

CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
    // TODO: x needs to be within 1 ~ 0x7fffffff;
    // __builtin_clz will produce an unexpected result if x is 0
    return 31 - __builtin_clz(x);
}

CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
    // TODO: x needs to be within 1 ~ 0x7fffffff
    return x == (1 << integer_log2_floor(x));
}
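The power-of-two helpers above are count-leading-zeros tricks. A standalone sketch (using `__builtin_clz` directly, a GCC/Clang builtin that is undefined for a zero argument, hence the input-range TODOs above):

```cpp
#include <cstdint>

// 1 << (32 - clz(x - 1)) rounds x up to the next power of two for x >= 2.
constexpr int32_t next_power_of_two(int32_t x) { return 1 << (32 - __builtin_clz(x - 1)); }

// 31 - clz(x) is the index of the highest set bit, i.e. floor(log2(x)).
constexpr int32_t integer_log2_floor(int32_t x) { return 31 - __builtin_clz(x); }

// x is a power of two exactly when rebuilding it from its top bit gives x back.
constexpr bool is_power_of_two_integer(int32_t x) { return x == (1 << integer_log2_floor(x)); }

static_assert(next_power_of_two(17) == 32);
static_assert(next_power_of_two(32) == 32); // already a power of two stays put
static_assert(integer_log2_floor(48) == 5);
static_assert(is_power_of_two_integer(64) && !is_power_of_two_integer(48));
```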

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

template <typename T>
struct log2e;

template <>
struct log2e<double>
{
    static constexpr double value = C_LOG2E;
};

template <>
struct log2e<float>
{
    static constexpr float value = C_LOG2E;
};

template <typename T = double>
constexpr T log2e_v = log2e<T>::value;

// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;
    y.u32 = y.u32 & 0x7fffffff;
    return y.f32;
}

CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
    uint32_t xx = bit_cast<uint32_t>(x);
    return (xx & 0x7fffffff) > 0x7F800000;
}

CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };

CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };

CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };

CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };

CK_TILE_DEVICE
float exp(float x) { return __expf(x); };

CK_TILE_HOST
float exp(float x) { return std::expf(x); }

CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };

CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };

CK_TILE_DEVICE
float log(float x) { return __logf(x); };

CK_TILE_HOST
float log(float x) { return std::logf(x); };

} // namespace ck_tile
191 include/ck_tile/core/numeric/numeric.hpp Normal file
@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
#include <type_traits> // for std::is_same_v used in log2e() below

namespace ck_tile {

// this struct carries
// 1. the limits of a certain type, similar to std::numeric_limits
// 2. some pre-defined values: zero, one, ...
//
template <typename T>
struct numeric
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and the next representable value
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }

    CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }

    CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

    CK_TILE_HOST_DEVICE static constexpr T log2e()
    {
        if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
        {
            return static_cast<T>(C_LOG2E);
        }
        else
        {
            return 0; // TODO: integer?
        }
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<float>
{
    static constexpr int exp           = 8;
    static constexpr int mant          = 23;
    static constexpr int bias          = 127;
    static constexpr uint32_t nan_mask = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
    static constexpr uint32_t Neg0      = 0x80000000;
    using bitwise_type = uint32_t;
};
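The constants in `numeric_traits<float>` describe the IEEE-754 single-precision layout: sign(1) | exponent(8, bias 127) | mantissa(23). A standalone sketch checking them against the bit layout (`std::memcpy` stands in for `bit_cast`):

```cpp
#include <cstdint>
#include <cstring>

constexpr uint32_t mant_bits = 23;
constexpr uint32_t bias      = 127;
constexpr uint32_t nan_mask  = 0x7F800000; // all exponent bits set
constexpr uint32_t mant_mask = 0x7FFFFF;   // low 23 bits

// the masks are consistent with the field widths
static_assert(nan_mask == (0xFFu << mant_bits));
static_assert(mant_mask == (1u << mant_bits) - 1);

// decode the biased exponent field of a float's bit pattern
inline uint32_t exponent_field(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return (u >> mant_bits) & 0xFF;
}
```

For example, `exponent_field(1.0f)` yields the bias 127, and `exponent_field(2.0f)` yields 128, since 2.0 = 1.0 x 2^1.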

} // namespace ck_tile

#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)                 \
    attr_ bool operator==(const type_& x, const type_& y)            \
    {                                                                \
        return static_cast<float>(x) == static_cast<float>(y);       \
    }                                                                \
    attr_ bool operator!=(const type_& x, const type_& y)            \
    {                                                                \
        return static_cast<float>(x) != static_cast<float>(y);       \
    }                                                                \
    attr_ bool operator<(const type_& x, const type_& y)             \
    {                                                                \
        return static_cast<float>(x) < static_cast<float>(y);        \
    }                                                                \
    attr_ bool operator<=(const type_& x, const type_& y)            \
    {                                                                \
        return static_cast<float>(x) <= static_cast<float>(y);       \
    }                                                                \
    attr_ bool operator>(const type_& x, const type_& y)             \
    {                                                                \
        return static_cast<float>(x) > static_cast<float>(y);        \
    }                                                                \
    attr_ bool operator>=(const type_& x, const type_& y)            \
    {                                                                \
        return static_cast<float>(x) >= static_cast<float>(y);       \
    }                                                                \
    attr_ type_ operator+(const type_& x, const type_& y)            \
    {                                                                \
        return type_(static_cast<float>(x) + static_cast<float>(y)); \
    }                                                                \
    attr_ type_ operator-(const type_& x)                            \
    {                                                                \
        constexpr uint32_t bits = sizeof(type_) * 8;                 \
        constexpr uint32_t mask = 1 << (bits - 1);                   \
        type_ y = x;                                                 \
        y.data ^= static_cast<typename type_::raw_type>(mask);       \
        return y;                                                    \
    }                                                                \
    attr_ type_ operator-(const type_& x, const type_& y)            \
    {                                                                \
        return type_(static_cast<float>(x) - static_cast<float>(y)); \
    }                                                                \
    attr_ type_ operator*(const type_& x, const type_& y)            \
    {                                                                \
        return type_(static_cast<float>(x) * static_cast<float>(y)); \
    }                                                                \
    attr_ type_ operator/(const type_& x, const type_& y)            \
    {                                                                \
        return type_(static_cast<float>(x) / static_cast<float>(y)); \
    }                                                                \
    attr_ type_& operator+=(type_& x, const type_& y)                \
    {                                                                \
        x = type_(static_cast<float>(x) + static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    attr_ type_& operator-=(type_& x, const type_& y)                \
    {                                                                \
        x = type_(static_cast<float>(x) - static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    attr_ type_& operator*=(type_& x, const type_& y)                \
    {                                                                \
        x = type_(static_cast<float>(x) * static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    attr_ type_& operator/=(type_& x, const type_& y)                \
    {                                                                \
        x = type_(static_cast<float>(x) / static_cast<float>(y));    \
        return x;                                                    \
    }                                                                \
    attr_ type_& operator++(type_& x)                                \
    {                                                                \
        x = type_(static_cast<float>(x) + 1.f);                      \
        return x;                                                    \
    }                                                                \
    attr_ type_& operator--(type_& x)                                \
    {                                                                \
        x = type_(static_cast<float>(x) - 1.f);                      \
        return x;                                                    \
    }                                                                \
    attr_ type_ operator++(type_& x, int)                            \
    {                                                                \
        type_ y(x);                                                  \
        x = type_(static_cast<float>(x) + 1.f);                      \
        return y;                                                    \
    }                                                                \
    attr_ type_ operator--(type_& x, int)                            \
    {                                                                \
        type_ y(x);                                                  \
        x = type_(static_cast<float>(x) - 1.f);                      \
        return y;                                                    \
    }
66 include/ck_tile/core/numeric/type_convert.hpp Normal file
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"

namespace ck_tile {

#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
    return static_cast<Y>(x);
}
#else
// Convert X to Y; both X and Y are non-const data types.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
    return static_cast<Y>(x);
}

// Convert X to Y; either X or Y is a const data type.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    using non_const_y = std::remove_const_t<Y>;
    using non_const_x = std::remove_const_t<X>;
    return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
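The two overloads above use SFINAE to route const-qualified source or destination types through the non-const path. A standalone sketch of that dispatch pattern (plain `static_cast` only, without the ck_tile fp8/fp16 specializations):

```cpp
#include <type_traits>

// Overload for non-const Y and X: plain conversion.
template <typename Y, typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
constexpr Y type_convert(X x)
{
    return static_cast<Y>(x);
}

// Overload when either side is const: strip const, recurse into the
// non-const overload, then convert to the qualified destination type.
template <typename Y, typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
constexpr Y type_convert(X x)
{
    return static_cast<Y>(type_convert<std::remove_const_t<Y>, std::remove_const_t<X>>(x));
}

static_assert(type_convert<int>(3.7f) == 3);       // first overload
static_assert(type_convert<const int>(3.7f) == 3); // second overload delegates to the first
```

The point of the split is that explicit specializations (like the fp16/fp8 conversions below) only need to cover the non-const overload; const-qualified requests still reach them through the delegating overload.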

#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)                    \
    template <>                                                                 \
    CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
    {                                                                           \
        return sname_##_to_##dname_(x);                                         \
    }

CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)

CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

#undef CK_TILE_TYPE_CONVERT
#endif

} // namespace ck_tile
185 include/ck_tile/core/numeric/vector_type.hpp Normal file
@@ -0,0 +1,185 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// this structure is used to pick the <base> type inside
//     using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allows native types + bool in this position (custom types will fail).
// specialize this structure to supply the proper <base> type

template <typename T>
struct native_t
{
    using type = remove_cvref_t<T>;
};

// we name this ext_vector on purpose, because clang's ext_vector_type extension only accepts
// literal basic types when constructing an ext_vector_type. Be very careful using this, or you
// will get lots of compiler errors, e.g.
//     struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); // -> compiler error
namespace impl {
template <typename T_, index_t N_>
struct ext_vector
{
    static constexpr index_t N = N_;
    using value_type           = typename native_t<remove_cvref_t<T_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
    static constexpr index_t N = Vs_ * N_;
    using value_type           = typename native_t<remove_cvref_t<V_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

} // namespace impl

template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;

// by default, any type results in vector_size = 1 with scalar_type = T traits
// ... unless another vector_traits specialization applies
template <typename T>
struct vector_traits
{
    using scalar_type = remove_cvref_t<T>;
    static constexpr index_t vector_size = 1;
};

// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
    using scalar_type = T;
    static constexpr index_t vector_size = N;
};

template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                          typename vector_traits<remove_cvref_t<Y>>::scalar_type>;

// below are some pre-defined ext_vector_type aliases
// attention: two vector type aliases could name exactly the same type
// fp64
using fp64_t   = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));

// fp32
using fp32_t    = float;
using fp32x2_t  = float __attribute__((ext_vector_type(2)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));
using fp32x8_t  = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));

// fp16
// using fp16_t = ...
using fp16x2_t  = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t  = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t  = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));

// bf16
// using bf16_t = ...
using bf16x2_t  = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t  = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t  = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));

// i32
// using int32_t = ...
using int32x2_t  = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t  = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t  = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

// i16
// using int16_t = ...
using int16x2_t  = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t  = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t  = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));

// u16
// using uint16_t
using uint16x2_t  = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t  = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t  = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));

// i8
// using int8_t
using int8x2_t  = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t  = fp8_raw_t __attribute__((ext_vector_type(2)));
using fp8x4_t  = fp8_raw_t __attribute__((ext_vector_type(4)));
using fp8x8_t  = fp8_raw_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_raw_t __attribute__((ext_vector_type(2)));
using bf8x4_t  = bf8_raw_t __attribute__((ext_vector_type(4)));
using bf8x8_t  = bf8_raw_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
|
||||
#else
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
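These aliases rely on the Clang/GCC vector extension: arithmetic and subscripting operate elementwise on the whole vector. A minimal host-side sketch, using the portable `vector_size` spelling (ck_tile uses the Clang-style `ext_vector_type`, which behaves the same for these operations); the `fp32x4_demo_t` and `add4` names are illustrative only:

```cpp
#include <cassert>

// Portable GCC/Clang vector extension; ck_tile's headers use the
// Clang-style ext_vector_type spelling with the same semantics.
using fp32x4_demo_t = float __attribute__((vector_size(4 * sizeof(float))));

// Elementwise add: one expression operates on all four lanes.
inline fp32x4_demo_t add4(fp32x4_demo_t a, fp32x4_demo_t b) { return a + b; }
```

Vector types like these map directly onto VGPR tuples on AMD GPUs, which is why the wider widths (x32, x64) exist for MFMA accumulators.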
include/ck_tile/core/tensor/buffer_view.hpp (Normal file, 1068 lines)
File diff suppressed because it is too large
include/ck_tile/core/tensor/load_tile.hpp (Normal file, 81 lines)
@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
                                  bool_constant<oob_conditional_check> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window)
{
    return tile_window.async_load(lds_tile);
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
    return null_tensor{};
}

template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}

} // namespace ck_tile
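The `bool_constant<oob_conditional_check>` parameter above is a compile-time tag: the out-of-bounds check is selected in the type system, so the untaken path is never instantiated and costs nothing at run time. A self-contained sketch of the same idiom using `std::integral_constant` (the `load_demo` name and return values are illustrative, not ck_tile API):

```cpp
#include <cassert>
#include <type_traits>

// Tag dispatch on a compile-time boolean, mirroring how load_tile
// forwards bool_constant<oob_conditional_check> to the tile window.
template <bool OobCheck>
constexpr int load_demo(std::integral_constant<bool, OobCheck>)
{
    if constexpr(OobCheck)
        return 1; // stands in for the bounds-checked load path
    else
        return 0; // stands in for the raw load path
}
```

`std::true_type` / `std::false_type` are the standard-library analogues of ck_tile's `bool_constant<true>` / `bool_constant<false>`.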
include/ck_tile/core/tensor/null_tensor.hpp (Normal file, 12 lines)
@@ -0,0 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

struct null_tensor
{
};

} // namespace ck_tile
include/ck_tile/core/tensor/null_tile_window.hpp (Normal file, 88 lines)
@@ -0,0 +1,88 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"

namespace ck_tile {

// placeholder type if we want to opt-out a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;

    using BottomTensorIndex = array<index_t, WindowLengths::size()>;

    CK_TILE_DEVICE constexpr null_tile_window() = default;

    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    WindowLengths window_lengths_;
};

// utility to check if this is a null_tile_window
namespace impl {
template <typename>
struct is_null_tile_window : public std::false_type
{
};

template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl

template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}

template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}

} // namespace ck_tile
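`impl::is_null_tile_window` is the classic specialization-detection trait: a primary template inheriting `std::false_type` plus a partial specialization for the wrapper that inherits `std::true_type`. A minimal free-standing sketch of the same idiom (the `window` / `is_window` names are illustrative, not ck_tile types):

```cpp
#include <cassert>
#include <type_traits>

// Stand-in for the wrapper being detected (e.g. null_tile_window<T>).
template <typename T>
struct window
{
};

// Primary template: anything is "not a window" by default.
template <typename>
struct is_window : std::false_type
{
};

// Partial specialization: matches any window<T>.
template <typename T>
struct is_window<window<T>> : std::true_type
{
};
```

Pairing the trait with a free function taking `const T&`, as `is_null_tile_window` does above, lets callers test a value instead of spelling out its type.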
include/ck_tile/core/tensor/shuffle_tile.hpp (Normal file, 177 lines)
@@ -0,0 +1,177 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"

namespace ck_tile {
namespace detail {

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    using DataType = typename InTensor::DataType;

    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;

        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;

        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];

            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });

        return rh_major_minor_to_y_;
    };

    constexpr auto rh_major_minor_to_y_in  = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});

    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;

        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }

        return y_dim_out_to_in_;
    }();

    //
    constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();

    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = array<DataType, vec_length_in>;
    using OutVec = array<DataType, vec_length_out>;

    // using InVec = typename InVec::type;
    // using OutVec = typename OutVec::type;

    // SFC
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();

    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    thread_buffer<InVec, num_vec_in> in_vectors;
    thread_buffer<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
            static_assert(in_offset % vec_length_in == 0);

            in_vectors(i).template get_as<InVec>()(I0) =
                in_tensor.get_thread_buffer()
                    .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
                number<NDimY>{});

            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
            static_assert(out_offset % vec_length_out == 0);

            out_tensor.get_thread_buffer().template set_as<OutVec>(
                number<out_offset / vec_length_out>{},
                out_vectors[i].template get_as<OutVec>()[I0]);
        });
    });
}

} // namespace detail

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InDstrEncode  = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;

    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);

    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
    }
}

} // namespace ck_tile
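`transpose_vectors` is the workhorse of the in-thread shuffle: `num_vec_in` vectors of length `vec_length_in` come out as `num_vec_out` vectors of length `vec_length_out`, which is what lets the shuffle move the vectorized Y dimension. A host-side analogue over `std::array` (a sketch of the data movement only, not the ck_tile implementation, which works on register vectors):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// N vectors of length M in, M vectors of length N out:
// out[j][i] == in[i][j].
template <typename T, std::size_t N, std::size_t M>
std::array<std::array<T, N>, M> transpose_demo(const std::array<std::array<T, M>, N>& in)
{
    std::array<std::array<T, N>, M> out{};
    for(std::size_t i = 0; i < N; ++i)
        for(std::size_t j = 0; j < M; ++j)
            out[j][i] = in[i][j];
    return out;
}
```

In the device code the same exchange is done per space-filling-curve access, so only `num_vec_in * vec_length_in` elements are live in registers at a time.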
include/ck_tile/core/tensor/slice_tile.hpp (Normal file, 92 lines)
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
    // NOTE: This API will override the origin of the tile window!
    static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
    static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());

    constexpr auto slice_lengths = slice_ends - slice_begins;

    return make_tile_window(tile.get_bottom_tensor_view(),
                            sequence_to_tuple_of_number(slice_lengths),
                            to_multi_index(slice_begins));
}

template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType     = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);

    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);

    return sliced_tensor;
}

template <typename DstDataType_,
          typename DstStaticTileDistribution_,
          typename SrcDataType_,
          typename SrcStaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
               const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");

    dst_tile.set_y_sliced_thread_data(
        sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}

} // namespace ck_tile
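`get_slice_tile` derives the new window lengths as `slice_ends - slice_begins`, element by element, entirely at compile time. The same computation over `std::index_sequence` (a sketch; `slice_lengths_demo` is an illustrative name, and begins ≤ ends is assumed):

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Elementwise ends - begins, carried out in the type system so the
// resulting lengths are themselves compile-time constants.
template <std::size_t... B, std::size_t... E>
constexpr auto slice_lengths_demo(std::index_sequence<B...>, std::index_sequence<E...>)
{
    return std::index_sequence<(E - B)...>{};
}
```

Keeping the lengths in a type (rather than values) is what allows the sliced window to again be a `tile_window_with_static_lengths`.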
include/ck_tile/core/tensor/static_distributed_tensor.hpp (Normal file, 190 lines)
@@ -0,0 +1,190 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"

namespace ck_tile {

template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    static_assert(StaticTileDistribution::is_static(),
                  "wrong! StaticTileDistribution should be known at compile time");

    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
        return StaticTileDistribution::get_num_of_dimension_x();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
        return StaticTileDistribution::get_lengths();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
    {
        return StaticTileDistribution{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        return StaticTileDistribution::get_distributed_spans();
    }

    CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }

    CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
    {
        return kThreadElementSpaceSize;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths>
    CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>) const
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
            sliced_thread_data;

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
        });

        return sliced_thread_data;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
    CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>,
                                                      const SlicedThreadData& sliced_thread_data)
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
        });
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
    }

    //
    thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};

template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{};
}

template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
                                                                  ThreadBuffer&& thread_buffer_)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}

// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
                                       DistributedIndices distributed_indices)
{
    const auto partition_index = detail::get_partition_index(tile_distribution);
    constexpr auto y_indices =
        tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);

    const auto x_coord = make_tensor_adaptor_coordinate(
        tile_distribution.get_ps_ys_to_xs_adaptor(),
        container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));

    return x_coord.get_bottom_index();
}

template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
            DataType value,
            XIndicesPredicate predicate)
{
    constexpr auto out_spans =
        static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
    sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
            constexpr auto distributed_indices = make_tuple(idx0, idx1);
            const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
                                                                          distributed_indices);

            if(predicate(x_indices))
            {
                out_tensor(distributed_indices) = value;
            }
        });
    });
}

} // namespace ck_tile
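The slice helpers above address the thread buffer through `make_naive_tensor_descriptor_packed`, i.e. a row-major (packed) layout: `offset = ((i0 * L1 + i1) * L2 + i2) ...`. A host sketch of that offset computation (the `packed_offset_demo` name is illustrative, not ck_tile API):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Row-major offset for a packed descriptor with the given lengths:
// Horner's scheme over the dimensions, innermost dimension fastest.
template <std::size_t N>
constexpr std::size_t packed_offset_demo(const std::array<std::size_t, N>& lengths,
                                         const std::array<std::size_t, N>& idx)
{
    std::size_t offset = 0;
    for(std::size_t d = 0; d < N; ++d)
        offset = offset * lengths[d] + idx[d];
    return offset;
}
```

Because the Y slice lengths are template parameters, the real descriptor evaluates this arithmetic entirely at compile time, so every buffer access resolves to a constant index.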
include/ck_tile/core/tensor/store_tile.hpp (Normal file, 93 lines)
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store_raw(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                WindowLengths_,
                                                TileDistribution_,
                                                NumCoord>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                    WindowLengths_,
                                                    TileDistribution_,
                                                    NumCoord>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}

} // namespace ck_tile

30  include/ck_tile/core/tensor/sweep_tile.hpp  Normal file
@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// sweep over a span of a distributed tile and apply lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
          typename F                     // signature: F(tile_distributed_index<...>)
          >
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
    using DstrSpan = remove_cvref_t<TileDistributedSpan_>;

    static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
        constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);

        f(dstr_idx);
    });
}

} // namespace ck_tile
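`sweep_tile_span` is a thin wrapper: `static_ford` enumerates every index tuple in the span's compile-time index space and invokes `f` once per tuple. The following standalone sketch shows that "compile-time nested loop" pattern in plain C++17; `static_ford_sketch` and the helper functions are illustrative names, not the ck_tile implementation.

```cpp
#include <cassert>
#include <utility>

// Minimal stand-in for a compile-time "ford" loop: visit every index tuple in
// an N-dimensional space whose extents are template parameters, in row-major
// order, calling f(i0, i1, ...) with values usable at compile time.
template <int... Lengths>
struct static_ford_sketch;

template <>
struct static_ford_sketch<>
{
    template <typename F, int... Is>
    static void run(const F& f, std::integer_sequence<int, Is...>)
    {
        f(Is...); // innermost level: all indices fixed
    }
};

template <int L0, int... Ls>
struct static_ford_sketch<L0, Ls...>
{
    template <typename F, int... Is>
    static void run(const F& f, std::integer_sequence<int, Is...> = {})
    {
        run_impl(f, std::integer_sequence<int, Is...>{}, std::make_integer_sequence<int, L0>{});
    }

    private:
    template <typename F, int... Is, int... I>
    static void run_impl(const F& f,
                         std::integer_sequence<int, Is...>,
                         std::integer_sequence<int, I...>)
    {
        // unroll the current dimension 0..L0-1, recursing on the rest
        (static_ford_sketch<Ls...>::run(f, std::integer_sequence<int, Is..., I>{}), ...);
    }
};

// count how many times f is invoked over a 2x3 space
inline int count_visits_2x3()
{
    int n = 0;
    static_ford_sketch<2, 3>::run([&](int, int) { ++n; });
    return n;
}

// check that visits arrive in row-major order
inline bool visits_in_row_major_order()
{
    int expected = 0;
    bool ok      = true;
    static_ford_sketch<2, 3>::run([&](int i, int j) {
        ok = ok && (i * 3 + j == expected);
        ++expected;
    });
    return ok;
}
```

The real `static_ford` additionally passes the indices as types (so the lambda can use them in `constexpr` context, as `sweep_tile_span` does with `make_tile_distributed_index`), but the iteration structure is the same.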

945  include/ck_tile/core/tensor/tensor_adaptor.hpp  Normal file
@@ -0,0 +1,945 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"

namespace ck_tile {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct tensor_adaptor
{
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
    {
        return Transforms::size();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return LowerDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return UpperDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
    {
        return BottomDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return TopDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);

                constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});

                constexpr index_t itran   = tmp[number<0>{}];
                constexpr index_t idim_up = tmp[number<1>{}];
                constexpr bool found      = tmp[number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];

                return length;
            },
            number<ndim_top_>{});

        // TODO: make container_reduce support tuple of number and index_t
        return container_reduce(lengths, multiplies{}, number<1>{});
    }

    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_transform_and_its_upper_dimension(number<IDimHidden>)
    {
        // FIXME: length of bottom dimension is not known, since info about lower dim length is
        // not saved in the transformation
        static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == IDimHidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
    {
        return BottomDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
    {
        return TopDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      less<index_t>,
                                                                      equal<index_t>>::type;

        return unique_sort_all_dim_ids::size();
    }

    constexpr static index_t ntransform_  = get_num_of_transform();
    constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
    constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
    constexpr static index_t ndim_top_    = get_num_of_top_dimension();

    using HiddenIndex = multi_index<ndim_hidden_>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    // may be index_t or number<>
    using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }

    // FIXME: this logic is wrong when getting bottom dimension lengths
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
    {
        static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");

        constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});

        constexpr index_t itran   = tmp[number<0>{}];
        constexpr index_t idim_up = tmp[number<1>{}];
        constexpr bool found      = tmp[number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
    }

    template <index_t IDimTop>
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    template <index_t IDimBottom>
    CK_TILE_HOST_DEVICE constexpr index_t
    get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
    }
#endif

    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
    {
        return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
                              number<ndim_top_>{});
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
    CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
    {
        return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
                              number<ndim_bottom_>{});
    }
#endif

    template <typename TopIdx>
    CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
                      "wrong! # of dimensions inconsistent");

        constexpr index_t ntransform  = get_num_of_transform();
        constexpr index_t ndim_hidden = get_num_of_hidden_dimension();

        multi_index<ndim_hidden> idx_hidden;

        // initialize topmost index
        set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - number<1>{};
            const auto& tran        = get_transforms().at(itran);
            constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            multi_index<dims_low.size()> idx_low;

            tran.calculate_lower_index(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool is_known = true;

        static_for<0, Transforms::size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
        });

        return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
        const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
        const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
    {
        auto vector_lengths = guaranteed_vector_lengths;
        auto vector_strides = guaranteed_vector_strides;

        static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto up_dims  = get_upper_dimension_hidden_idss().at(itran);

            const auto up_guaranteed_vector_lengths =
                get_container_subset(guaranteed_vector_lengths, up_dims);
            const auto up_guaranteed_vector_strides =
                get_container_subset(guaranteed_vector_strides, up_dims);

            // only need type of transform
            auto [up_vector_lengths, up_vector_strides] =
                Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                    get_container_subset(vector_lengths, low_dims),
                    get_container_subset(vector_strides, low_dims));

            if constexpr(up_dims.size() > 0)
            {
                for(index_t i = 0; i < up_dims.size(); ++i)
                {
                    up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
                                               ? up_guaranteed_vector_lengths[i]
                                               : up_vector_lengths[i];

                    up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
                                               ? up_guaranteed_vector_strides[i]
                                               : up_vector_strides[i];
                }
            }

            set_container_subset(vector_lengths, up_dims, up_vector_lengths);
            set_container_subset(vector_strides, up_dims, up_vector_strides);
        });

        constexpr auto top_dims = TopDimensionHiddenIds{};

        return make_tuple(get_container_subset(vector_lengths, top_dims),
                          get_container_subset(vector_strides, top_dims));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_adaptor{");

        //
        printf("transforms: ");
        print(transforms_);
        printf(", ");

        //
        printf("LowerDimensionHiddenIds: ");
        print(LowerDimensionHiddenIdss{});
        printf(", ");

        //
        printf("UpperDimensionHiddenIds: ");
        print(UpperDimensionHiddenIdss{});
        printf(", ");

        //
        printf("BottomDimensionHiddenIds: ");
        print(BottomDimensionHiddenIds{});
        printf(", ");

        //
        printf("TopDimensionHiddenIds: ");
        print(TopDimensionHiddenIds{});

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
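The heart of `tensor_adaptor` is `calculate_bottom_index`: it seeds the hidden-index slots tagged as top ids with the incoming top index, replays the transforms last-to-first so each one fills its lower slots from its upper slots, and finally gathers the bottom slots. A hypothetical scalar sketch of that walk with a single merge transform (hidden ids 0 and 1 are the bottom 2-D index, hidden id 2 is the top linearized index; all names here are illustrative, not ck_tile API):

```cpp
#include <array>
#include <cassert>

// Hidden-index buffer for a toy adaptor with one merge transform:
// hidden dims {0, 1} form the bottom (2D) index, hidden dim 2 is the top
// (linearized) index. Replaying the merge backwards (its
// calculate_lower_index step) is a row-major divmod against the bottom
// dimension lengths.
inline std::array<int, 2> calculate_bottom_index_sketch(int idx_top, int len0, int len1)
{
    std::array<int, 3> idx_hidden{}; // one slot per hidden dimension id

    idx_hidden[2] = idx_top; // seed the top slot

    // merge transform: upper dim {2} <- lower dims {0, 1}
    // lower index from upper index: divmod by the fastest-varying length
    idx_hidden[0] = idx_hidden[2] / len1;
    idx_hidden[1] = idx_hidden[2] % len1;

    (void)len0; // only needed for bounds checking in a fuller sketch
    return {idx_hidden[0], idx_hidden[1]}; // gather the bottom slots
}
```

In the real adaptor the same scatter/gather happens through `set_container_subset`/`get_container_subset` keyed by the hidden-id sequences, and `static_for<ntransform, 0, -1>` drives the reverse replay at compile time.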

// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::size();

    static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
                      UpperDimensionNewTopIdss::size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
        number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};

    return tensor_adaptor<remove_cvref_t<Transforms>,
                          remove_cvref_t<decltype(low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}

// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is defined outside the scope where it is
// used (transform_tensor_adaptor) because a template cannot be defined inside a
// function template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
    {
        using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
        return number<Tran::get_num_of_upper_dimension()>{};
    }
};

template <typename OldTensorAdaptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
                         const NewTransforms& new_transforms,
                         NewLowerDimensionOldTopIdss,
                         NewUpperDimensionNewTopIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
                          NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
                      "wrong! inconsistent number of transforms");

        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldTopIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewTopIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_top_ids) constexpr {
            return transform_sequences(
                // convert lower dimension top id to hidden id
                [](auto low_dim_top_id) constexpr {
                    return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
                },
                low_dim_top_ids);
        },
        NewLowerDimensionOldTopIdss{});

    constexpr index_t num_new_transform = NewTransforms::size();

    // upper dimension's hidden idss
    constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});

    constexpr auto up_dim_numbers_scan = merge_sequences(
        sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
            return
                typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
                                                 1>::type{};
        },
        number<num_new_transform>{});

    // new top dimension's hidden ids
    constexpr auto unordered_new_top_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_top_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

    constexpr auto new_top_dim_hidden_ids =
        unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);

    // put everything together
    const auto all_transforms =
        container_concat(old_tensor_adaptor.get_transforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);

    return tensor_adaptor<
        remove_cvref_t<decltype(all_transforms)>,
        remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
        remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
        remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
        remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}

template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
                      TensorAdaptor1::get_num_of_bottom_dimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();

        static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();

        static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();

            // get the min of all lower dimensions, but not bottom dimensions (because their
            // ids will be matched with top ids from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    min(adaptor1_min_hidden_id_,
                        TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is a bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::get_bottom_dimension_hidden_ids()
                                         [idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }();

            return generate_sequence_v2(
                [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_low_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_up_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};

    // put everything together
    return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
                          remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}

template <typename X,
          typename... Xs,
          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck_tile
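Viewed as index maps, `chain_tensor_adaptors(adaptor0, adaptor1)` builds the composition `bottom = adaptor0(adaptor1(top))` (the hidden-id shifting and bottom/top matching above are bookkeeping to make that composition well-formed), and the variadic overload right-folds: `chain(x, xs...) = chain(x, chain(xs...))`. A minimal sketch of the same composition structure with plain functions instead of adaptors (`chain_sketch` is an illustrative name, not ck_tile API):

```cpp
#include <cassert>

// Right-fold composition over a variadic list of callables, mirroring
// chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)).
template <typename F>
auto chain_sketch(F f)
{
    return f; // base case: a single map is its own chain
}

template <typename F, typename G, typename... Hs>
auto chain_sketch(F f, G g, Hs... hs)
{
    // compose f after the chain of the remaining maps; the last map in the
    // list is applied first, like adaptor1's transforms during
    // calculate_bottom_index
    auto rest = chain_sketch(g, hs...);
    return [f, rest](int i) { return f(rest(i)); };
}

// apply (v - 3), then (v * 2), then (v + 1)
inline int chained_example(int i)
{
    auto h = chain_sketch([](int v) { return v + 1; },  // applied last
                          [](int v) { return v * 2; },
                          [](int v) { return v - 3; }); // applied first
    return h(i);
}
```

The real chaining additionally concatenates the transform tuples and relabels hidden dimension ids so the merged id space stays collision-free; the composition order, however, is exactly the one sketched here.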

// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of the following:
//    1.1 name (coord_transform_enum)
//    1.2 meta data for the constructor of the transform
//    1.3 num of lower dimensions (index_t)
//    1.4 lower dimension ids (array of fixed size)
//    1.5 num of upper dimensions (index_t)
//    1.6 upper dimension ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension ids (array of fixed size)
// 4. num of bottom dimensions (index_t)
// 5. encoded top dimension ids (array of fixed size)
// 6. num of top dimensions (index_t)
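Each encoded transform carries a flat `meta_data` buffer that the macro decodes with sequential `pop<T>(pos)` calls: every pop reads one typed field at the cursor and advances it. A hypothetical host-side sketch of that cursor-style encoding and decoding (`meta_data_sketch` is an illustrative stand-in, not ck_tile's meta-data type):

```cpp
#include <array>
#include <cassert>
#include <cstring>

// Cursor-style access to a flat metadata buffer: pop<T>(pos) reads a T at
// byte offset pos and advances pos, like meta_data.template pop<T>(pos) in
// the encoded-transform decoding.
struct meta_data_sketch
{
    std::array<unsigned char, 64> buf{};

    template <typename T>
    void push(int& pos, const T& value)
    {
        std::memcpy(buf.data() + pos, &value, sizeof(T));
        pos += static_cast<int>(sizeof(T));
    }

    template <typename T>
    T pop(int& pos) const
    {
        T value{};
        std::memcpy(&value, buf.data() + pos, sizeof(T));
        pos += static_cast<int>(sizeof(T));
        return value;
    }
};

// encode then decode the fields of a pad transform: low_len, left_pad, right_pad
inline std::array<int, 3> roundtrip_pad_meta()
{
    meta_data_sketch md;
    int wpos = 0;
    md.push(wpos, 16); // low_len
    md.push(wpos, 2);  // left_pad
    md.push(wpos, 3);  // right_pad

    int rpos = 0;
    return {md.pop<int>(rpos), md.pop<int>(rpos), md.pop<int>(rpos)};
}
```

Because the field order is fixed per transform kind, the decoder needs no tags in the buffer; the `coord_transform_enum` name alone determines how many fields to pop and in what order.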
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                            \
    [encoded_tensor_adaptor]() {                                                                  \
        using namespace ck_tile;                                                                  \
                                                                                                  \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();             \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();             \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();             \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();             \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();             \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();             \
                                                                                                  \
        constexpr auto trans = [&encoded_transforms]() {                                          \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) constexpr {                                         \
                    constexpr auto name        = encoded_transforms[i].template at<0>();          \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();          \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();          \
                                                                                                  \
                    static_assert(name == coord_transform_enum::pass_through ||                   \
                                      name == coord_transform_enum::pad ||                        \
                                      name == coord_transform_enum::embed ||                      \
                                      name == coord_transform_enum::merge ||                      \
                                      name == coord_transform_enum::unmerge ||                    \
                                      name == coord_transform_enum::replicate,                    \
                                  "");                                                            \
                                                                                                  \
                    if constexpr(name == coord_transform_enum::pass_through)                      \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto low_len = meta_data.template pop<index_t>(pos);                      \
                                                                                                  \
                        return make_pass_through_transform(low_len);                              \
                    }                                                                             \
                    else if constexpr(name == coord_transform_enum::pad)                          \
                    {                                                                             \
                        index_t pos    = 0;                                                       \
                        auto low_len   = meta_data.template pop<index_t>(pos);                    \
                        auto left_pad  = meta_data.template pop<index_t>(pos);                    \
                        auto right_pad = meta_data.template pop<index_t>(pos);                    \
                                                                                                  \
                        return make_pad_transform(low_len, left_pad, right_pad);                  \
                    }                                                                             \
                    else if constexpr(name == coord_transform_enum::embed)                        \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                        auto coefficients =                                                       \
                            meta_data.template pop<array<index_t, num_up_dim>>(pos);              \
                                                                                                  \
                        return make_embed_transform(up_lens, coefficients);                       \
                    }                                                                             \
                    else if constexpr(name == coord_transform_enum::merge)                        \
                    {                                                                             \
                        index_t pos   = 0;                                                        \
                        auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
                                                                                                  \
                        return make_merge_transform(low_lens);                                    \
                    }                                                                             \
                    else if constexpr(name == coord_transform_enum::unmerge)                      \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_unmerge_transform(up_lens);                                   \
                    }                                                                             \
                    else if constexpr(name == coord_transform_enum::replicate)                    \
                    {                                                                             \
                        index_t pos  = 0;                                                         \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);   \
                                                                                                  \
                        return make_replicate_transform(up_lens);                                 \
                    }                                                                             \
                },                                                                                \
                number<num_transform>{});                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() {                   \
            return generate_tuple(                                                                \
                [&encoded_transforms](auto i) {                                                   \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();          \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();          \
                                                                                                  \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                    \
                },                                                                                \
                number<num_transform>());                                                         \
        }();                                                                                      \
                                                                                                  \
        constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] {                      \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
|
||||
\
|
||||
return TO_SEQUENCE(up_dims, num_up_dim); \
|
||||
}, \
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
|
||||
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
|
||||
\
|
||||
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
|
||||
remove_cvref_t<decltype(low_dim_idss)>, \
|
||||
remove_cvref_t<decltype(up_dim_idss)>, \
|
||||
remove_cvref_t<decltype(bottom_dim_ids)>, \
|
||||
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
|
||||
}()
|
||||
|
||||
// Macro function
// construct static tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of the following:
//    1.1 name (coord_transform_enum)
//    1.2 meta data for constructor of the transform
//    1.3 num of lower dimensions (index_t)
//    1.4 lower dimension Ids (array of fixed size)
//    1.5 num of upper dimensions (index_t)
//    1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimensions (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimensions (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
    [encoded_tensor_adaptor]() { \
        using namespace ck_tile; \
\
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>(); \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>(); \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>(); \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>(); \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>(); \
\
        constexpr auto trans = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) constexpr { \
                    constexpr auto name        = encoded_transforms[i].template at<0>(); \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>(); \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>(); \
\
                    static_assert(name == coord_transform_enum::pass_through || \
                                      name == coord_transform_enum::pad || \
                                      name == coord_transform_enum::embed || \
                                      name == coord_transform_enum::merge || \
                                      name == coord_transform_enum::unmerge || \
                                      name == coord_transform_enum::replicate, \
                                  ""); \
\
                    if constexpr(name == coord_transform_enum::pass_through) \
                    { \
                        constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
                        return make_pass_through_transform(number<low_len>{}); \
                    } \
                    else if constexpr(name == coord_transform_enum::pad) \
                    { \
                        constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
                        constexpr index_t left_pad = \
                            meta_data.template get<index_t>(sizeof(low_len)); \
\
                        constexpr index_t right_pad = \
                            meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad)); \
\
                        return make_pad_transform( \
                            number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
                    } \
                    else if constexpr(name == coord_transform_enum::embed) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
\
                        constexpr auto coefficients = \
                            meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
\
                        return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
                                                    TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::merge) \
                    { \
                        constexpr auto low_lens = \
                            meta_data.template get<array<index_t, num_low_dim>>(0); \
\
                        return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::unmerge) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
\
                        return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::replicate) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
\
                        return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    } \
                }, \
                number<num_transform>{}); \
        }(); \
\
        constexpr auto low_dim_idss = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>(); \
\
                    return TO_SEQUENCE(low_dims, num_low_dim); \
                }, \
                number<num_transform>()); \
        }(); \
\
        constexpr auto up_dim_idss = [&encoded_transforms] { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>(); \
\
                    return TO_SEQUENCE(up_dims, num_up_dim); \
                }, \
                number<num_transform>()); \
        }(); \
\
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
        return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
                              remove_cvref_t<decltype(low_dim_idss)>, \
                              remove_cvref_t<decltype(up_dim_idss)>, \
                              remove_cvref_t<decltype(bottom_dim_ids)>, \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
    }()
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp (new file, 257 lines)
@@ -0,0 +1,257 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
    static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
    static constexpr index_t ndim_top_    = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
    {
        return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
    {
        return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }

    //
    HiddenIndex idx_hidden_;
};

template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  const TopIndex& idx_top)
{
    static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform  = Adaptor::get_num_of_transform();
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
    constexpr auto top_dim_ids    = Adaptor::get_top_dimension_hidden_ids();

    multi_index<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, top_dim_ids, idx_top);

    // calculate hidden index
    static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - number<1>{};
        const auto& tran        = adaptor.get_transforms().at(itran);
        constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
        constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        multi_index<dims_low.size()> idx_low;

        tran.calculate_lower_index(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return tensor_adaptor_coordinate<ndim_hidden,
                                     remove_cvref_t<decltype(bottom_dim_ids)>,
                                     remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}

template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex,
          typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top,
                                                                  BottomIndex& idx_diff_bottom)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr index_t ndim_top    = Adaptor::get_num_of_top_dimension();
    // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    // static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");

    // judge whether calculation of lower diff is needed for each transform
    // use index_t for boolean type
    auto do_transforms = make_zero_multi_index<ntransform>();

    if constexpr(JudgeDoTransforms)
    {
        auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

        // decide do_transforms by checking the non-zero index diff components
        multi_index<ndim_top> non_zero_diff_pick_top;

        static_for<0, ndim_top, 1>{}(
            [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });

        set_container_subset(
            is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);

        static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

            multi_index<dims_low.size()> non_zero_diff_pick_low;

            // if any of the upper index diff components is non-zero, then
            // 1) this transform needs to be done
            // 2) all components of the lower index diff are assumed to be non-zero and need
            //    to be computed
            const bool idx_diff_up_has_non_zero = container_reduce(
                non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

            do_transforms(itran) = idx_diff_up_has_non_zero;

            static_for<0, dims_low.size(), 1>{}(
                [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

            set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
        });
    }
    else
    {
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
    }

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize top index diff
    set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);

    // this is what needs to be updated
    auto& idx_hidden = coord.get_hidden_index();

    // update top index
    auto idx_hidden_pick_top =
        get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());

    idx_hidden_pick_top += idx_diff_top;

    set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(do_transforms[itran])
        {
            const auto& tran        = adaptor.get_transforms().at(itran);
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            multi_index<dims_low.size()> idx_diff_low;

            tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });

    // set bottom index diff
    idx_diff_bottom =
        get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}

template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top)
{
    constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();

    multi_index<ndim_bottom> tmp;

    move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}

template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
                                                        const AdaptorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    const auto& idx_hidden = coord.get_hidden_index();

    static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
        const auto tran = adaptor.get_transforms().at(itran);

        // check validity only if the current transformation does not always have a valid mapping
        if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
        {
            const auto idx_up = get_container_subset(
                idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));

            // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
        }
    });

    return valid;
}

template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
                                                               const AdaptorCoord& coord)
{
    // check top index
    const auto& idx_top = coord.get_top_index();

    bool is_top_index_valid = true;

    static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
        [&is_top_index_valid, &idx_top, &adaptor](auto i) {
            is_top_index_valid =
                is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
        });

    // check other hidden index
    return is_top_index_valid &&
           adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}

} // namespace ck_tile
include/ck_tile/core/tensor/tensor_coordinate.hpp (new file, 92 lines)
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
    : public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
    using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
        : Base{idx_hidden}
    {
    }

    // construct from tensor_adaptor_coordinate base class
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }

    CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
    {
        return Base::get_bottom_index()[number<0>{}];
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
    {
        return Base::get_hidden_index();
    }

    CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};

template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const TopIndex& idx_top)
{
    const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);

    return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
                             remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
        adaptor_coord};
}

template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}

template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
                                                        const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}

template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid(tensor_desc, coord);
}

} // namespace ck_tile
include/ck_tile/core/tensor/tensor_descriptor.hpp (new file, 467 lines)
@@ -0,0 +1,467 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
// TopDimensionHiddenIds : sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths_,
          typename GuaranteedVectorStrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
                                                 LowerDimensionHiddenIdss,
                                                 UpperDimensionHiddenIdss,
                                                 sequence<0>,
                                                 TopDimensionHiddenIds>
{
    using Base = tensor_adaptor<Transforms,
                                LowerDimensionHiddenIdss,
                                UpperDimensionHiddenIdss,
                                sequence<0>,
                                TopDimensionHiddenIds>;

    using ElementSpaceSizeType = ElementSpaceSize;

    constexpr static index_t ntransform_  = Base::get_num_of_transform();
    constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
    constexpr static index_t ndim_top_    = Base::get_num_of_top_dimension();

    using GuaranteedVectorLengths = GuaranteedVectorLengths_;
    using GuaranteedVectorStrides = GuaranteedVectorStrides_;

    static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
                      GuaranteedVectorStrides::size() == ndim_hidden_,
                  "wrong! inconsistent # of hidden dimensions");

    using TopIndex    = multi_index<ndim_top_>;
    using HiddenIndex = multi_index<ndim_hidden_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
                                                    ElementSpaceSize element_space_size)
        : Base{transforms}, element_space_size_{element_space_size}

    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    // construct from tensor_adaptor base class
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
                                                    ElementSpaceSize element_space_size)
        : Base{adaptor}, element_space_size_{element_space_size}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return Base::get_num_of_top_dimension();
    }

    template <index_t IDim>
    CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
    {
        return Base::get_top_dimension_length(idim);
    }

    CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
    {
        return Base::get_top_dimension_lengths();
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
    {
        return element_space_size_;
    }

    template <typename Idx>
    CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
    {
        return Base::calculate_bottom_index(idx)[number<0>{}];
    }

    // TODO make these private
    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
    {
        return Base::get_transforms();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return Base::get_lower_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return Base::get_upper_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return Base::get_top_dimension_hidden_ids();
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return Base::is_known_at_compile_time() &&
               ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
    {
        return Base::get_top_dimension_safe_vector_length_strides(
            to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
            to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_descriptor{");

        // tensor_adaptor
        Base::print();
        printf(", ");

        // element_space_size_
        printf("element_space_size_: ");
        print(element_space_size_);

        printf("}");
    }

    // TODO make these private
    ElementSpaceSize element_space_size_;
};

template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
                                    const ElementSpaceSize& element_space_size)
{
    constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();

    return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
                             remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
                             remove_cvref_t<decltype(element_space_size)>,
                             typename uniform_sequence_gen<NDimHidden, -1>::type,
                             typename uniform_sequence_gen<NDimHidden, -1>::type>{
        adaptor, element_space_size};
}

template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldTopIdss,
                            NewUpperDimensionNewTopIdss)
{
    const auto element_space_size = old_tensor_desc.get_element_space_size();

    const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
                                                             new_transforms,
                                                             NewLowerDimensionOldTopIdss{},
                                                             NewUpperDimensionNewTopIdss{});

    constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
    constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();

    using NewGuaranteedVectorLengths = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorLengths,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    using NewGuaranteedVectorStrides = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorStrides,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    return tensor_descriptor<
        remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
        remove_cvref_t<decltype(element_space_size)>,
        NewGuaranteedVectorLengths,
        NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}

namespace detail {

template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     number<I> i,
                                                                     AccOld acc_old)
{
    auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];

    if constexpr(i.value < Lengths::size() - 1)
    {
        return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
    }
    else
    {
        return acc_new;
    }
}

} // namespace detail
/*
 * These functions create naive tensor descriptors
 */

// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
template <typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
                             const tuple<Strides...>& strides,
                             number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                             number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size =
        detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorStride>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
// Tensor descriptor with an offset. The offset is not added into the element space size;
// it only records the starting offset, which affects offset calculation.
template <typename... Lengths,
          typename... Strides,
          typename offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
                                         const tuple<Strides...>& strides,
                                         const offset& os,
                                         number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                                         number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    const auto desc_0 = [&]() {
        const auto element_space_size = detail::calculate_element_space_size_impl(
            lengths, strides, number<0>{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, os));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorStride>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_embed_transform(lengths, strides)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}

// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
                                    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}
template <typename... Lengths,
          typename... Strides,
          typename Offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
    const tuple<Lengths...>& lengths,
    const Offset& offset,
    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    const auto desc_0 = [&]() {
        const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_unmerge_transform(lengths)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}

// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// align could be:
// 1) index_t, or
// 2) number<>
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return number<stride_n_minus_2>{};
            }
            else
            {
                return container_reduce(
                    lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
            }
        },
        number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}
} // namespace ck_tile

include/ck_tile/core/tensor/tensor_view.hpp (new file, 281 lines)
@@ -0,0 +1,281 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
    using buffer_view = remove_reference_t<BufferView_>;
    using DataType    = typename buffer_view::type;
    using TensorDesc  = remove_cvref_t<TensorDesc_>;
    using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
    using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));

    CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
                                              const TensorDesc& desc)
        : buf_{buffer_view}, desc_{desc}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return TensorDesc::get_num_of_top_dimension();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }

#if 0
    CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
    {
        return buf_.template get<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
    {
        buf_.template set<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }
#endif

    // X is a vector of DataType.
    // "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X.
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    // X is a vector of DataType.
    // "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X.
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE void
    get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                const TensorCoord& coord,
                                bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get_raw<X, oob_conditional_check>(
            dst,
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
                                                                     const TensorCoord& coord) const
    {
        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
    }

    // X is a vector of DataType.
    // "coord" is a coordinate of DataType, not of X; "coord" should be aligned to X.
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");

        // buf_
        printf("buf_: ");
        print(buf_);
        printf(", ");

        // desc_
        printf("desc_: ");
        print(desc_);

        printf("}");
    }

    // members
    buffer_view buf_;
    TensorDesc desc_;
};

// placeholder type used to opt out of a tile-view parameter
struct null_tensor_view
{
};

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
                                                    const tensor_descriptor<Ts...>& desc)
{
    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
                       const tuple<Lengths...>& lengths,
                       const tuple<Strides...>& strides,
                       number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                       number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    auto desc = make_naive_tensor_descriptor(lengths,
                                             strides,
                                             number<GuaranteedLastDimensionVectorLength>{},
                                             number<GuaranteedLastDimensionVectorStride>{});

    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Lengths,
          index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
                              const tuple<Lengths...>& lengths,
                              number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    auto desc =
        make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});

    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <typename OldTensorView,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
                                                         const NewTransforms& new_transforms,
                                                         NewLowerDimensionOldVisibleIdss,
                                                         NewUpperDimensionNewVisibleIdss)
{
    auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
                                                new_transforms,
                                                NewLowerDimensionOldVisibleIdss{},
                                                NewUpperDimensionNewVisibleIdss{});

    return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
        old_tensor_view.buf_, new_desc};
}

template <typename TensorView,
          typename TileLengths, // tuple<...>
          typename DoPads>      // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
    constexpr index_t num_dim = DoPads::size();

    static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
                  "wrong! inconsistent # of dimensions");

    // transforms
    const auto transforms = generate_tuple(
        [&](auto idim) {
            const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);

            const auto tile_length = tile_lengths[idim];

            const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;

            const auto pad_length = new_length - old_length;

            constexpr bool DoPad = DoPads::at(idim);

            const auto transform =
                conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
                                        make_pass_through_transform(old_length));

            return transform;
        },
        number<num_dim>{});

    // lower dimension ids
    const auto lower_dimss =
        generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});

    // upper dimension ids
    const auto upper_dimss = lower_dimss;

    return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}
} // namespace ck_tile

include/ck_tile/core/tensor/tile_distribution.hpp (new file, 759 lines)
@@ -0,0 +1,759 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// distributed span
|
||||
template <index_t... PartialHsLengths>
|
||||
struct tile_distributed_span
|
||||
{
|
||||
using Impl = sequence<PartialHsLengths...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
// distributed index
|
||||
template <index_t... PartialHsIndices>
|
||||
struct tile_distributed_index
|
||||
{
|
||||
using Impl = sequence<PartialHsIndices...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_span<Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_index<Is...>{};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
|
||||
// should be more elegnat
|
||||
struct tile_distribution
|
||||
{
|
||||
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
|
||||
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
|
||||
|
||||
static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
|
||||
"wrong! should be static");
|
||||
|
||||
static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
|
||||
static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
|
||||
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
|
||||
|
||||
PsYs2XsAdaptor ps_ys_to_xs_;
|
||||
Ys2DDescriptor ys_to_d_;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
#if 0
|
||||
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
|
||||
ps_ys_to_xs_.GetBottomDimensionLengths();
|
||||
#else
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t x_length =
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
|
||||
|
||||
return number<x_length>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
|
||||
{
|
||||
return ps_ys_to_xs_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
|
||||
{
|
||||
return DstrEncode{};
|
||||
}
|
||||
|
||||
#if 1
|
||||
// Calculate Replication index [R0, R1, ...] based on Partion index
|
||||
// FIXME: very nasty implementation
|
||||
template <typename PartitionIndex>
|
||||
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
|
||||
{
|
||||
static_assert(PartitionIndex::size() == NDimP, "wrong!");
|
||||
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
|
||||
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
|
||||
array<index_t, NDimR> rs_idx;
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
|
||||
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
|
||||
|
||||
// 0-th rh_major is the replicate dimension
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
constexpr index_t adaptor_hidden_id =
|
||||
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
|
||||
|
||||
// fill in
|
||||
rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return rs_idx;
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
|
||||
constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto span_impl = distributed_spans_impl[i];
|
||||
constexpr index_t ndim_span_minor = ndims_spans_minor[i];
|
||||
|
||||
constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
|
||||
|
||||
return detail::make_tile_distributed_span(span);
|
||||
},
|
||||
number<NDimX>{});
|
||||
}
|
||||
|
||||
// FIXME: it's hacky to get Y index from Distributed-Index
|
||||
template <typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
{
|
||||
constexpr auto ys_idx_arr = [] {
|
||||
array<index_t, NDimY> ys_idx;
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
|
||||
constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
|
||||
|
||||
constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
|
||||
|
||||
ys_idx(i) = dstr_index.impl_[span_minor];
|
||||
});
|
||||
|
||||
return ys_idx;
|
||||
}();
|
||||
|
||||
constexpr index_t ndim_y = NDimY;
|
||||
|
||||
return TO_SEQUENCE(ys_idx_arr, ndim_y);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
//
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(DstrEncode{});
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_d_: ");
|
||||
print(ys_to_d_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t NDimMax>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
|
||||
{
|
||||
array<index_t, NDimMax> arr{0};
|
||||
|
||||
for(index_t i = 0; i < iend - ibegin; ++i)
|
||||
{
|
||||
arr(i) = ibegin + i;
|
||||
}
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
|
||||
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
|
||||
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
|
||||
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
|
||||
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
|
||||
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
|
||||
|
||||
// FIXME: increase max value if fail
|
||||
constexpr index_t kMaxNumTransforms = 20;
|
||||
constexpr index_t kMaxMetaDataSize = 128;
|
||||
constexpr index_t kMaxNumDim = 10;
|
||||
|
||||
using Name = coord_transform_enum;
|
||||
    using MetaData = meta_data_buffer<kMaxMetaDataSize>;
    using NumDim   = index_t;
    using Dims     = array<index_t, kMaxNumDim>;
    using Lengths  = array<index_t, kMaxNumDim>;

    // Tile Adaptor
    //   bottom dims [x0, x1, x2, ...]
    //   top dims [p0, p1, ..., y0, y1, ...]
    constexpr index_t ndim_x = HsLengthss::size();

    // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
    array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;

    auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};

    index_t num_tran       = 0;
    index_t hidden_dim_cnt = ndim_x;

    // this is the replicate transform
    {
        constexpr index_t ndim_r_minor = RsLengths::size();

        constexpr auto r_minor_lengths = RsLengths{};

        trans(num_tran++) = {
            coord_transform_enum::replicate,
            MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
            NumDim{0},
            Dims{},
            NumDim{ndim_r_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};

        for(index_t i = 0; i < ndim_r_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(0)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];

            hidden_dim_cnt++;
        }
    };

    // these are unmerge transforms for the X dimensions
    static_for<0, ndim_x, 1>{}([&trans,
                                &num_tran,
                                &hidden_dim_cnt,
                                &rh_major_minor_to_hidden_ids,
                                &rh_major_minor_to_hidden_lengths](auto idim_x) {
        // typename HsLengthss::base{}.foo();
        constexpr auto h_minor_lengths =
            HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
        // constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});

        constexpr index_t ndim_h_minor = h_minor_lengths.size();

        trans(num_tran++) = {
            coord_transform_enum::unmerge,
            MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
            NumDim{1},
            Dims{idim_x},
            NumDim{ndim_h_minor},
            make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};

        for(index_t i = 0; i < ndim_h_minor; ++i)
        {
            rh_major_minor_to_hidden_ids(idim_x + 1)(i)     = hidden_dim_cnt;
            rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];

            hidden_dim_cnt++;
        }
    });

    // transform: P dimensions
    constexpr index_t ndim_p = Ps2RHssMajor::size();

    Dims hidden_dim_id_ps;

    static_for<0, ndim_p, 1>{}([&](auto iDimP) {
        //
        index_t hidden_dim_id_p = hidden_dim_cnt++;

        hidden_dim_id_ps(iDimP) = hidden_dim_id_p;

        constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
        constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];

        static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");

        constexpr index_t ndim_low = p2RHsMajor.size();

        Dims low_dims;
        Lengths low_lengths;

        for(index_t i = 0; i < ndim_low; ++i)
        {
            index_t rh_major = p2RHsMajor[i];
            index_t rh_minor = p2RHsMinor[i];
            low_dims(i)      = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
            low_lengths(i)   = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        }

        trans(num_tran++) = {coord_transform_enum::merge,
                             MetaData{to_array<index_t, ndim_low>(low_lengths)},
                             NumDim{ndim_low},
                             low_dims,
                             NumDim{1},
                             Dims{hidden_dim_id_p}};
    });
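The merge transform recorded for each P dimension above linearizes its low (hidden) indices into a single flat index. A minimal standalone sketch of that linearization, assuming row-major order (rightmost low dim varies fastest); `merge_index` is an illustrative name, not the ck_tile API:

```cpp
#include <cassert>

// row-major linearization: flat = ((i0*l1 + i1)*l2 + i2)...
// a model of what coord_transform_enum::merge encodes
inline int merge_index(const int* idx, const int* lengths, int n)
{
    int flat = 0;
    for(int i = 0; i < n; ++i)
        flat = flat * lengths[i] + idx[i];
    return flat;
}
```

For example, with low lengths `{4, 8}` the index pair `{2, 5}` maps to the flat index `2*8 + 5 = 21`.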

    constexpr index_t ndim_bottom = ndim_x;

    constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);

    constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
    constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};

    constexpr index_t ndim_y   = Ys2RHsMajor::size();
    constexpr index_t ndim_top = ndim_p + ndim_y;

    auto top_dim_ids = hidden_dim_id_ps;

    {
        for(index_t i = 0; i < ndim_y; ++i)
        {
            index_t rh_major = ys_to_rhs_major[i];
            index_t rh_minor = ys_to_rhs_minor[i];
            top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
        }
    }

    //
    const auto ps_ys_to_xs_adaptor_encoding =
        make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);

    // descriptor: [y0, y1, ...] to [d]
    Lengths y_lengths;
    index_t d_length = 1;

    for(index_t i = 0; i < ndim_y; ++i)
    {
        index_t rh_major = ys_to_rhs_major[i];
        index_t rh_minor = ys_to_rhs_minor[i];
        index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
        y_lengths(i) = y_length;
        d_length *= y_length;
    }

    auto tran = make_tuple(coord_transform_enum::unmerge,
                           MetaData{to_array<index_t, ndim_y>(y_lengths)},
                           NumDim{1},
                           Dims{0},
                           NumDim{ndim_y},
                           make_sequential_index<kMaxNumDim>(1, ndim_y + 1));

    const auto ys_to_d_adaptor_encoding = make_tuple(
        make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);

    return make_tuple(ps_ys_to_xs_adaptor_encoding,
                      ys_to_d_adaptor_encoding,
                      d_length,
                      rh_major_minor_to_hidden_ids);
}

// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
    static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
        to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};

} // namespace detail

// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);

    constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}

// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);

    constexpr auto ys_to_d_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}

//***********************************************************************************

namespace detail {

template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");

    if constexpr(Distribution::NDimP == 1)
    {
        return array<index_t, 1>{get_lane_id()};
    }
    else if constexpr(Distribution::NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
}

template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx at which the sliced length does not equal the original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};

template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx at which the sliced length does not equal the original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};

// clang-format off
// input: a sequence (with an optional mask), and SliceSize: the size per slice
// output: the sliced length of each dim, and the number of slices per dim
//
// e.g. <2, 1, 4, 2>, 8     -> lengths:<1, 1, 4, 2>    , nums: <2, 1, 1, 1>    : 2 slices  , slice_idx: 0
//      <4, 2, 4, 1, 2>, 4  -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
//      <4, 2, 4, 1, 6>, 4  -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
//      <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices  , slice_idx: 1
//
//      <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices  , slice_idx: 0
//      <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices  , slice_idx: 0
//      <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices  , slice_idx: 0
//      <4, 2, 8>, 8  -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices  , slice_idx: 1
//      <4, 2, 8>, 4  -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
//      <4, 2, 8>, 2  -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
//      <4, 2, 8>, 1  -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
//      <4, 2, 1, 4, 2> / 4 ->
//      mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// returns tuple<slice_lengths, slice_nums, slice_index>; slice_index is the index at which the
// slices start to split (scanning right -> left), i.e. the first index whose sliced length
// differs from the original length
// clang-format on
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        reverse_slice_sequence_impl<Seq,
                                    Mask,
                                    typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                    SliceSize>;
    static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}
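The compile-time scan above can be modeled at runtime: walk the dims right to left, greedily taking `gcd(dim, remaining)` as the sliced length of each dim; `split_idx` is the rightmost dim whose sliced length differs from the original at the point where the slice size is fully consumed. A hedged sketch (names like `reverse_slice` are illustrative, not the ck_tile API; all dims are treated as unmasked):

```cpp
#include <array>
#include <cassert>
#include <numeric> // std::gcd

struct slice_result
{
    std::array<int, 8> lengths{}; // sliced length per dim
    std::array<int, 8> slices{};  // number of slices per dim
    int split_idx = 0;            // rightmost dim cut while consuming the last of the slice
};

inline slice_result reverse_slice(const int* dims, int n, int slice_size)
{
    slice_result r;
    int remaining = slice_size;
    bool found    = false;
    for(int i = n - 1; i >= 0; --i)
    {
        const int len = std::gcd(dims[i], remaining);
        r.lengths[i]  = len;
        r.slices[i]   = dims[i] / len;
        remaining /= len;
        if(!found && len != dims[i] && remaining == 1)
        {
            r.split_idx = i;
            found       = true;
        }
    }
    assert(remaining == 1 && "slice size must evenly divide the sequence");
    return r;
}
```

For example, `<4, 2, 8>` with slice size 8 reproduces the documented result: lengths `<1, 1, 8>`, nums `<4, 2, 1>`, slice_idx 1.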

//
// slice a tensor along x_dim, resulting in a split in y_dim, not p_dim.
// We don't support slicing across p_dim (aka, slicing across different threads).
// Also, the sliced y_dim needs to be the first dim of the current x dim;
// multiplying a Y dim before the sliced dim does not make sense
//
// e.g.
//     X0           X1
//    <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
//     Y  P  P      Y  P  Y  P  Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
//                  |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
//     X0           X1
//    <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
//     Y  P  P      Y  P  Y  P  Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
//                        |--> slice along this Y dim, the P dim to its left is 1, so this is OK
//                             totally 16 slices
//
//     X0           X1
//    <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
//     Y  P  P      Y  P  Y  P  Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
//                           |--> slice along this P dim would split threads, not supported
//
//     X0           X1
//    <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
//     Y  P  P      Y  P  Y  P  Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
//                        |--> slice along this Y dim, but this Y dim needs to split into 2
//                             sub-dims; the P dim to its left is 1, so we are actually not
//                             crossing P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
    Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
    // NOTE: this function needs to be called in a constexpr context;
    // due to https://wg21.link/p2280r0 we have to use a non-reference type for the distribution
    using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());

    static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));

    constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;

    constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
    constexpr auto src_y_info       = Encoding::detail::get_sorted_y_info();
    constexpr auto src_y_dims       = src_y_info[number<0>{}];
    constexpr auto src_y_maps       = src_y_info[number<1>{}];
    constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];

    constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
    {
        auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
        auto y_slice_lengths        = Encoding::detail::ys_lengths_;

        // This lambda modifies values captured from the enclosing scope, so C++ will not
        // treat its return value as constexpr
        // TODO: ugly
        auto new_h_lengths = transform_tuples(
            [&](auto h_len, auto id) {
                constexpr auto sliced_h =
                    reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});

                constexpr auto sliced_h_lens  = sliced_h[number<0>{}];
                constexpr auto sliced_h_index = sliced_h[number<2>{}];

                // update y_slice_lengths
                constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
                constexpr auto found_y_index     = container_find(src_y_dims, uniformed_h_index);

                static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
                              "not sliced at y dim, please check");

                static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                    y_slice_lengths(src_y_maps[found_y_index - i]) =
                        sliced_h_lens[sliced_h_index - i];
                });
                // TODO: add validation that we do not slice across a p dim

                // NOTE: this y_origin is for all dims, not only the current dim;
                // we will later use pick to select the target dim
                constexpr auto y_origin = [&]() {
                    constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
                    auto h_origin_         = make_zero_multi_index<h_trans.NDimLow>();
                    h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});

                    auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
                    static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                        y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
                    });
                    return y_origin_;
                }();

                constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
                                                                          src_y_prefix_sum[id + 1],
                                                                          1>::type{};

                set_container_subset(
                    y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
                return sliced_h_lens;
            },
            typename Encoding::HsLengthss{},
            typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});

        auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);

        return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
    }();

    constexpr auto sliced_h_lengths       = sliced_hlen_yidx_ylen[number<0>{}];
    constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
    constexpr auto sliced_y_origins_size  = sliced_y_origins_array.size();
    constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
    constexpr auto sliced_y_lengths_size  = sliced_y_lengths_array.size();

    constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
    constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);

    return make_tuple(
        make_static_tile_distribution(
            tile_distribution_encoding<typename Encoding::RsLengths,
                                       decltype(sliced_h_lengths), // only need to change the
                                                                   // h_lengths type
                                       typename Encoding::Ps2RHssMajor,
                                       typename Encoding::Ps2RHssMinor,
                                       typename Encoding::Ys2RHsMajor,
                                       typename Encoding::Ys2RHsMinor>{}),
        sliced_y_origins,
        sliced_y_lengths);
}
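The `y_origin` lambda above lowers a flat x slice origin into per-h-dim coordinates through the division/mod merge transform. A minimal runtime sketch of that lowering (the name `unmerge_index` is illustrative, not the ck_tile API), with the rightmost dim varying fastest:

```cpp
#include <cassert>

// division/mod lowering of a flat index into per-dim coordinates;
// a model of what make_merge_transform_v3_division_mod::calculate_lower_index computes
inline void unmerge_index(int flat, const int* lengths, int n, int* out)
{
    for(int i = n - 1; i >= 0; --i)
    {
        out[i] = flat % lengths[i];
        flat /= lengths[i];
    }
}
```

For example, with lengths `{4, 2, 4}` the flat index 10 lowers to `{1, 0, 2}`, and re-merging gives `(1*2 + 0)*4 + 2 = 10` back.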

} // namespace detail
} // namespace ck_tile

760 include/ck_tile/core/tensor/tile_distribution_encoding.hpp Normal file
@@ -0,0 +1,760 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename RsLengths_,    // sequence<...>
          typename HsLengthss_,   // tuple<sequence<...>, ...>
          typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
          typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
          typename Ys2RHsMajor_,  // sequence<...>
          typename Ys2RHsMinor_>  // sequence<...>
struct tile_distribution_encoding
{
    using RsLengths    = remove_cvref_t<RsLengths_>;
    using HsLengthss   = remove_cvref_t<HsLengthss_>;
    using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
    using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
    using Ys2RHsMajor  = remove_cvref_t<Ys2RHsMajor_>;
    using Ys2RHsMinor  = remove_cvref_t<Ys2RHsMinor_>;

    static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
    static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");

    static constexpr index_t NDimX = HsLengthss::size();
    static constexpr index_t NDimP = Ps2RHssMajor::size();
    static constexpr index_t NDimY = Ys2RHsMajor::size();
    static constexpr index_t NDimR = RsLengths::size();

    // FIXME: move into detail
    static constexpr auto rs_lengths_       = RsLengths{};
    static constexpr auto hs_lengthss_      = HsLengthss{};
    static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
    static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
    static constexpr auto ys_to_rhs_major_  = Ys2RHsMajor{};
    static constexpr auto ys_to_rhs_minor_  = Ys2RHsMinor{};

    // redundant but useful info
    // TODO: really bad code, should be over-hauled
    struct detail
    {
        // ndim_rh_major_, ndim_span_minor_
        static constexpr index_t ndim_rh_major_   = NDimX + 1;
        static constexpr index_t ndim_span_major_ = NDimX;

        // ndims_rhs_minor_[ndim_rh_major_]
        static constexpr auto ndims_rhs_minor_ = generate_array(
            [](auto i) {
                if constexpr(i.value == 0)
                {
                    return rs_lengths_.size();
                }
                else
                {
                    return hs_lengthss_[i - number<1>{}].size();
                }
            },
            number<ndim_rh_major_>{});

        // max_ndim_rh_minor_
        static constexpr index_t max_ndim_rh_minor_ =
            container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);

        // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_lengthss_ =
            to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));

        // ys_lengths_
        static constexpr auto ys_lengths_ = [] {
            array<index_t, NDimY> ys_lengths_tmp{-1};

            for(index_t i = 0; i < NDimY; i++)
            {
                index_t rh_major = ys_to_rhs_major_[i];
                index_t rh_minor = ys_to_rhs_minor_[i];

                ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
            }

            return ys_lengths_tmp;
        }();

        // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_major_minor_to_ys_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                constexpr index_t rh_major = ys_to_rhs_major_[i];
                constexpr index_t rh_minor = ys_to_rhs_minor_[i];

                rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
            });

            return rhs_major_minor_to_ys_tmp;
        }();

        // ndims_span_minor_[ndim_span_major_]
        static constexpr auto ndims_span_minor_ = [] {
            array<index_t, NDimX> ndims_span_minor{0};

            for(index_t i = 0; i < NDimY; i++)
            {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_span_minor(span_major)++;
            }

            return ndims_span_minor;
        }();

        // max_ndim_span_minor_
        static constexpr index_t max_ndim_span_minor_ =
            container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);

        // rhs_major_minor_to_span_minor_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_major_minor_to_span_minor_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
                {-1}};

            static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
                constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];

                index_t cnt_ndim_span_minor = 0;

                static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
                    constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];

                    if(idim_y >= 0)
                    {
                        rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;

                        cnt_ndim_span_minor++;
                    }
                });
            });

            return rhs_major_minor_to_span_minor;
        }();

        // ys_to_span_major_[NDimY]
        static constexpr auto ys_to_span_major_ =
            generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});

        // ys_to_span_minor_[NDimY]
        static constexpr auto ys_to_span_minor_ = generate_array(
            [](auto i) {
                return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
            },
            number<NDimY>{});

        // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
        static constexpr auto distributed_spans_lengthss_ = [] {
            array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
                distributed_spans_lengthss{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t rh_major = ys_to_rhs_major_[i];
                const index_t rh_minor = ys_to_rhs_minor_[i];

                const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];

                const index_t span_major = rh_major - 1;
                const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];

                distributed_spans_lengthss(span_major)(span_minor) = h_length;
            });

            return distributed_spans_lengthss;
        }();

        // ndims_distributed_spans_minor_[ndim_span_major_]
        static constexpr auto ndims_distributed_spans_minor_ = [] {
            array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_distributed_spans_minor(span_major)++;
            });

            return ndims_distributed_spans_minor;
        }();

        // does_p_own_r_[NDimP][NDimR]
        static constexpr auto does_p_own_r_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        if constexpr(rh_major == 0)
                        {
                            does_p_own_r(idim_p)(rh_minor) = true;
                        }
                    });
                });

                return does_p_own_r;
            }
            else
            {
                return array<array<bool, NDimR>, NDimP>{};
            }
        }();

        // ps_over_rs_derivative_[NDimP][NDimR]
        static constexpr auto ps_over_rs_derivative_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    index_t p_over_rh_derivative = 1;

                    static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];

                        if constexpr(rh_major == 0)
                        {
                            ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
                        }

                        p_over_rh_derivative *= rh_length;
                    });
                });

                return ps_over_rs_derivative;
            }
            else
            {
                return array<array<index_t, NDimR>, NDimP>{};
            }
        }();
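The lambda above computes, for each P dim, the stride ("derivative") of the flat p index with respect to each replication (R) dim it owns: walking the P dim's low dims right to left, the running product of their lengths is the amount the flat p index advances when that low dim increments. A runtime model under that reading (`LowDim` and `p_over_r_derivative` are hypothetical names, not the ck_tile API):

```cpp
#include <vector>

struct LowDim
{
    bool is_r;    // rh_major == 0, i.e. this low dim is a replication (R) dim
    int rh_minor; // which R dim it is (meaningful only when is_r)
    int length;   // rh length of this low dim
};

// deriv[r] = stride of the flat p index w.r.t. R dim r (0 if p does not own r)
inline std::vector<int> p_over_r_derivative(const std::vector<LowDim>& low_dims, int ndim_r)
{
    std::vector<int> deriv(ndim_r, 0);
    int stride = 1;
    for(int i = static_cast<int>(low_dims.size()) - 1; i >= 0; --i)
    {
        if(low_dims[i].is_r)
            deriv[low_dims[i].rh_minor] = stride;
        stride *= low_dims[i].length; // accumulate right-to-left, like the static_for above
    }
    return deriv;
}
```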

        // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
        CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
        {
            // <len_d0, len_d1, ...>
            // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
            constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
                [&](auto i) {
                    constexpr index_t size = HsLengthss{}[i].size();
                    return number<size>{};
                },
                number<NDimX>{});

            // <0, len_d0, len_d0+len_d1, ...>
            // e.g. seq<3, 5> --> seq<0, 3, 8>
            constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);

            return h_dim_prefix_sum;
        }
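The mapping `seq<3, 5> --> seq<0, 3, 8>` above is a prefix sum with a leading zero: entry i is the flat offset at which H-dim group i starts. A minimal standalone sketch (the name `prefix_sum_with_zero` is illustrative, not the ck_tile API):

```cpp
#include <array>
#include <cstddef>

// out[i] = sum of v[0..i-1]; out[0] == 0, out[N] == total
template <std::size_t N>
constexpr std::array<int, N + 1> prefix_sum_with_zero(const std::array<int, N>& v)
{
    std::array<int, N + 1> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i + 1] = out[i] + v[i];
    return out;
}
```

For example, `prefix_sum_with_zero({3, 5})` yields `{0, 3, 8}`, matching the comment above.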
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
// <0, 0, len_d0, len_d0+len_d1, ...>
|
||||
constexpr auto x_dim_prefix_sum = merge_sequences(
|
||||
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
|
||||
return x_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
{
|
||||
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
|
||||
|
||||
constexpr auto sorted_dims = typename sorted_idx::type{};
|
||||
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
|
||||
|
||||
constexpr auto sorted_histogram =
|
||||
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
|
||||
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
|
||||
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution_encoding::detail{");
        //
        printf("ndim_rh_major_: ");
        print(ndim_rh_major_);
        printf(", ");
        //
        printf("ndim_span_major_: ");
        print(ndim_span_major_);
        printf(", ");
        //
        printf("ndims_rhs_minor_: ");
        print(ndims_rhs_minor_);
        printf(", ");
        //
        printf("max_ndim_rh_minor_: ");
        print(max_ndim_rh_minor_);
        printf(", ");
        //
        printf("rhs_lengthss_: ");
        print(rhs_lengthss_);
        printf(", ");
        //
        printf("ys_lengths_: ");
        print(ys_lengths_);
        printf(", ");
        //
        printf("rhs_major_minor_to_ys_: ");
        print(rhs_major_minor_to_ys_);
        printf(", ");
        //
        printf("ndims_span_minor_: ");
        print(ndims_span_minor_);
        printf(", ");
        //
        printf("max_ndim_span_minor_: ");
        print(max_ndim_span_minor_);
        printf(", ");
        //
        printf("ys_to_span_major_: ");
        print(ys_to_span_major_);
        printf(", ");
        //
        printf("ys_to_span_minor_: ");
        print(ys_to_span_minor_);
        printf(", ");
        //
        printf("distributed_spans_lengthss_: ");
        print(distributed_spans_lengthss_);
        printf(", ");
        //
        printf("ndims_distributed_spans_minor_: ");
        print(ndims_distributed_spans_minor_);
        printf(", ");
        //
        printf("ps_over_rs_derivative_: ");
        print(ps_over_rs_derivative_);
        //
        printf("}");
    }
};

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution_encoding{");
        //
        printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
        //
        printf("rs_lengths_: ");
        print(rs_lengths_);
        printf(", ");
        //
        printf("hs_lengthss_: ");
        print(hs_lengthss_);
        printf(", ");
        //
        printf("ps_to_rhss_major_: ");
        print(ps_to_rhss_major_);
        printf(", ");
        //
        printf("ps_to_rhss_minor_: ");
        print(ps_to_rhss_minor_);
        printf(", ");
        //
        printf("ys_to_rhs_major_: ");
        print(ys_to_rhs_major_);
        printf(", ");
        //
        printf("ys_to_rhs_minor_: ");
        print(ys_to_rhs_minor_);
        printf(", ");
        //
        printf("detail: ");
        print(detail{});
        //
        printf("}");
    }
};

namespace detail {

template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
    static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");

    constexpr index_t NDimHMajor = OuterDstr::NDimX;

    using RsLengths =
        sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;

    constexpr auto hs_lengthss = generate_tuple(
        [&](auto i) {
            return merge_sequences(typename OuterDstr::HsLengthss{}[i],
                                   typename InnerDstr::HsLengthss{}[i]);
        },
        number<NDimHMajor>{});

    //
    constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
        array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;

        // R dimension
        rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();

        // Hs dimensions
        static_for<0, NDimHMajor, 1>{}([&](auto i) {
            rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
        });

        return rhs_major_2_ndim_outer_rhs_minor_;
    }();

    // Ps2RHssMinor
    constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
        [&](auto p) {
            constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
            constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];

            constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();

            constexpr auto updated_inner_p_2_rhss_minor = [&]() {
                array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;

                for(index_t i = 0; i < ndim_tmp; i++)
                {
                    index_t rh_major = inner_p_2_rhss_major[i];

                    index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                    updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
                }

                return updated_inner_p_2_rhss_minor_;
            }();

            return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
        },
        number<InnerDstr::NDimP>{});

    // Ys2RHsMinor
    constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
        constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
        constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};

        constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();

        constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
            array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;

            for(index_t i = 0; i < ndim_tmp; i++)
            {
                index_t rh_major = inner_ys_2_rhs_major[i];

                index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
            }

            return updated_inner_ys_2_rhs_minor__;
        }();

        return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
    }();

    //
    constexpr auto ps_2_rhss_major =
        container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});

    constexpr auto ps_2_rhss_minor =
        container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);

    //
    constexpr auto ys_2_rhs_major =
        merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});

    constexpr auto ys_2_rhs_minor =
        merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);

    return tile_distribution_encoding<RsLengths,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_2_rhss_major)>,
                                      remove_cvref_t<decltype(ps_2_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_2_rhs_major)>,
                                      remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}

template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto I1 = number<1>{};

    // FIXME: increase if fail
    constexpr index_t max_ndim_r_out = 20;
    constexpr index_t max_ndim_y_out = 20;

    //
    constexpr index_t ndim_p               = InDstr::NDimP;
    constexpr index_t ndim_x_in            = InDstr::NDimX;
    constexpr index_t ndim_y_in            = InDstr::NDimY;
    constexpr index_t ndim_rh_major_in     = InDstr::NDimX + 1;
    constexpr index_t ndim_x_out           = ndim_x_in - sizeof...(InReduceDimXs);
    constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;

    // ndims_ps_low
    constexpr auto ndims_ps_low = generate_array(
        [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});

    // is_rh_major_in_for_reduce
    array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};

    for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
    {
        index_t rh_major = reduce_dim_xs_in[i] + 1;

        is_rh_major_in_for_reduce(rh_major) = true;
    }

    // is_y_in_for_reduce
    array<bool, ndim_y_in> is_y_in_for_reduce{false};

    for(index_t i = 0; i < ndim_y_in; i++)
    {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];

        if(is_rh_major_in_for_reduce[rh_major])
        {
            is_y_in_for_reduce(i) = true;
        }
    }

    // is_rh_minor_in_for_y_reduce
    array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];
        index_t rh_minor = InDstr::ys_to_rhs_minor_[i];

        if(is_y_in_for_reduce[i])
        {
            is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
        }
    });

    // in2out_rh_major
    array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
    index_t cnt_ndim_rh_major_out = 0;

    for(index_t i = 0; i < ndim_rh_major_in; i++)
    {
        if(is_rh_major_in_for_reduce[i])
        {
            in2out_rh_major(i) = 0;
        }
        else
        {
            in2out_rh_major(i) = cnt_ndim_rh_major_out;

            cnt_ndim_rh_major_out++;
        }
    }

    // rs_lengths_out, in2out_rh_minor
    array<index_t, max_ndim_r_out> rs_lengths_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};

    // loop over input R dim
    for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
    {
        // rs_lengths_out
        rs_lengths_out(i) = InDstr::rs_lengths_[i];

        // in2out_rh_minor
        in2out_rh_minor(0)(i) = i;
    }

    // loop over input H Dim
    index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();

    static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
        constexpr auto h_major_in = rh_major_in - I1;

        constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();

        if(is_rh_major_in_for_reduce[rh_major_in])
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
                {
                    // rs_lengths_out
                    rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];

                    // in2out_rh_minor
                    in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;

                    cnt_ndim_r_out++;
                }
            }
        }
        else
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                // in2out_rh_minor
                in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
            }
        }
    });

    // ndim_r_out
    const index_t ndim_r_out = cnt_ndim_r_out;

    // ndims_hs_minor_out, hs_lengthss_out
    array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};

    index_t cnt_ndim_x_out = 0;

    static_for<0, ndim_x_in, 1>{}([&](auto i) {
        if(not is_rh_major_in_for_reduce[i + I1])
        {
            // ndims_hs_minor_out
            ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();

            // hs_lengthss_out
            static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
                [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });

            cnt_ndim_x_out++;
        }
    });

    // ps_to_rhss_major_out, ps_to_rhss_minor_out
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};

    static_for<0, ndim_p, 1>{}([&](auto idim_p) {
        static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
            index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
            index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];

            ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
            ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
        });
    });

    // ys_to_rhs_major_out, ys_to_rhs_minor_out
    array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
    array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};

    index_t cnt_ndim_y_out = 0;

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        if(not is_y_in_for_reduce[i])
        {
            index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
            index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];

            ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
            ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];

            cnt_ndim_y_out++;
        }
    });

    // ndim_y_out
    const index_t ndim_y_out = cnt_ndim_y_out;

    //
    return make_tuple(ndim_x_out,
                      ndim_p,
                      ndim_y_out,
                      ndim_r_out,
                      ndims_hs_minor_out,
                      ndims_ps_low,
                      rs_lengths_out,
                      hs_lengthss_out,
                      ps_to_rhss_major_out,
                      ps_to_rhss_minor_out,
                      ys_to_rhs_major_out,
                      ys_to_rhs_minor_out);
}

template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);

    constexpr index_t ndim_x             = impl.template at<0>();
    constexpr index_t ndim_p             = impl.template at<1>();
    constexpr index_t ndim_y             = impl.template at<2>();
    constexpr index_t ndim_r             = impl.template at<3>();
    constexpr auto ndims_hs_minor        = impl.template at<4>();
    constexpr auto ndims_ps_low          = impl.template at<5>();
    constexpr auto rs_lengths_impl       = impl.template at<6>();
    constexpr auto hs_lengthss_impl      = impl.template at<7>();
    constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
    constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
    constexpr auto ys_to_rhs_major_impl  = impl.template at<10>();
    constexpr auto ys_to_rhs_minor_impl  = impl.template at<11>();

    constexpr auto rs_lengths  = TO_SEQUENCE(rs_lengths_impl, ndim_r);
    constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
    constexpr auto ps_to_rhss_major =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
    constexpr auto ps_to_rhss_minor =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
    constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
    constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);

    return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_to_rhss_major)>,
                                      remove_cvref_t<decltype(ps_to_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_to_rhs_major)>,
                                      remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}

} // namespace detail
} // namespace ck_tile

include/ck_tile/core/tensor/tile_elementwise.hpp (new file, 263 lines)
@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// TODO: support tensors with different distribution
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}

template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}

template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}

template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}

// TODO: prefer to use a per-dword value to set a tensor, in case the compiler does not handle
// sub-dword tensors well...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
    constexpr index_t tensor_bytes =
        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
    if constexpr(v == 0 && tensor_bytes % 4 == 0)
    {
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
    }
    else
    {
        tile_elementwise_inout(
            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
            dstr_tensor);
    }
}

template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}

template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}

namespace impl {
// TODO: this is ugly
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value, and will generate a
    // v_mov_b32 vxxx [old] before the cvt, which results in unwanted ISA,
    // so we prepare an uninitialized variable on purpose and turn off the warning
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false -> WORD0

        constexpr int32_t m0 = 0x05040100;
        using vec_t          = array<OutDataType, 4>;

        vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}

#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes the src or dst (or both) data type is under 1 dword;
// we pack sub-dword values into 1 dword to avoid the compiler's default sub-dword behavior
// (which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type                   = remove_cvref_t<typename InTensor::DataType>;
    using o_type                   = remove_cvref_t<OutDataType>;
    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    using o_bulk_type =
        std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    // cast the sequence per-bulk
    static_for<0, iters, 1>{}([&](auto i) {
        union bulk_wrapper
        {
            o_bulk_type bulk{};
            o_type data[bulk_size];
        } o_bulk;

        // TODO: should use the function below, but somehow it results in spill (same as a C
        // for-loop)
        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        // TODO: fixme, should use the above!
        // static_assert(sizeof(i_type) / sizeof(o_type) == 2);
        // o_bulk.data[0] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
        // o_bulk.data[1] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    static_for<0, rems, 1>{}([&](auto r) {
        // TODO: introduce a local scratch pad?
        auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
#endif
} // namespace impl

template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
    {
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>,
                                   src_tensor);
    }
}

// no-op function for null_tensor arguments
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op function for null_tensor arguments
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}

} // namespace ck_tile

include/ck_tile/core/tensor/tile_window.hpp (new file, 740 lines)
@@ -0,0 +1,740 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
struct tile_window_with_static_distribution
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using TileDstr         = remove_cvref_t<StaticTileDistribution_>;

    using WindowAdaptor    = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;

    using DataType = remove_cvref_t<typename BottomTensorView::DataType>;

    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor     = BottomTensorDesc::get_num_of_dimension();

    static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
    static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // TODO: check WindowLengths and StaticTileDistribution are consistent

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");

    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of dimensions");

    using AdaptorTopIndex   = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));

    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));

    struct load_store_traits
    {
        private:
        static constexpr auto get_vector_dim_y_scalar_per_vector()
        {
            const auto [ys_vector_lengths, ys_vector_strides] =
                tile_window_with_static_distribution::
                    get_window_adaptor_ys_safe_vector_length_strides();

            index_t VectorDimY_      = 0;
            index_t ScalarPerVector_ = 1;

            for(index_t i = 0; i < NDimY; ++i)
            {
                if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
                {
                    ScalarPerVector_ = ys_vector_lengths[i];
                    VectorDimY_      = i;
                }
            }

            return make_tuple(VectorDimY_, ScalarPerVector_);
        }

        public:
        static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
        static constexpr index_t ScalarPerVector =
            get_vector_dim_y_scalar_per_vector().template at<1>();

        // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
        // using vector_t      = typename vector_type_t::type;
        using vector_t = thread_buffer<DataType, ScalarPerVector>;

        private:
        static constexpr auto scalars_per_access_ = [] {
            constexpr auto scalars_per_access_arr = generate_array(
                [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});

            /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
            constexpr auto NDimY_ = NDimY;

            return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
        }();

        static constexpr auto get_space_filling_curve()
        {
            constexpr auto tile_dstr = TileDstr{};

            constexpr auto thread_tensor_lengths_ys =
                to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());

            // FIXME: need logic to judge dim access order
            using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;

            return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                       DimAccessOrder,
                                       decltype(scalars_per_access_)>{};
        }

        public:
        using SFC_Ys = decltype(get_space_filling_curve());

        static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();

        static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
        static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
    };
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
    return TileDstr::is_static();
}

CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }

CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

// move the thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
    WindowAdaptorCoord& window_adaptor_thread_coord,
    BottomTensorCoord& bottom_tensor_thread_coord,
    const AdaptorTopIndex& idx_diff_adaptor_top) const
{
    array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

    move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                   window_adaptor_thread_coord,
                                   idx_diff_adaptor_top,
                                   idx_diff_adaptor_bottom);

    move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                           bottom_tensor_thread_coord,
                           idx_diff_adaptor_bottom);
}

// return the vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
    // bottom tensor top dimension vector lengths and strides
    const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
        BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

    // window vector lengths/strides
    const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
    const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;

    // window adaptor [p0, p1, ..., y0, y1, ...]
    array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
        -1};
    array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
        -1};

    constexpr auto window_adaptor_bottom_dims =
        WindowAdaptor::get_bottom_dimension_hidden_ids();

    set_container_subset(window_adaptor_vector_lengths,
                         window_adaptor_bottom_dims,
                         window_adaptor_bottom_dim_vector_lengths);
    set_container_subset(window_adaptor_vector_strides,
                         window_adaptor_bottom_dims,
                         window_adaptor_bottom_dim_vector_strides);

    const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
        WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
            window_adaptor_vector_lengths, window_adaptor_vector_strides);

    // [y0, y1, ...]
    constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
                                                             NDimWindowAdaptorTop,
                                                             1>::type{};

    return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
                      get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}

CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }

template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
{
    using Traits = load_store_traits;

    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structured bindings (to be captured later) when compiling as C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
#if 1
            // write into distributed tensor
            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % Traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });

    return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                             bool_constant<oob_conditional_check> = {}) const
{
    using Traits = load_store_traits;

    // using vector_type_t = typename Traits::vector_type_t;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    static constexpr index_t YElementSize =
        TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
    static_assert(YElementSize % Traits::ScalarPerVector == 0);
    using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
    // StaticBuffer<address_space_enum::vgpr,
    //              vector_t,
    //              YElementSize / Traits::ScalarPerVector,
    //              true>;

    constexpr auto tile_dstr = TileDstr{};

    auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structured bindings (to be captured later) when compiling as C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
            constexpr index_t d =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % Traits::ScalarPerVector == 0);

            get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                bottom_tensor_thread_coord,
                bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}

// TODO: currently async load is only implemented via inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                               bool_constant<oob_conditional_check> = {}) const
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
    using LdsDataType = typename LdsTileWindow::DataType;
    // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;

    // issues * warps * lanes
    static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

    const index_t size_per_buf =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<0>{}, number<0>{})) *
        sizeof(LdsDataType);

    const index_t size_per_wave =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<1>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t size_per_issue =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<1>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
    m0_set_with_memory(m0_init_value); // this should be wave-independent

    using Traits = load_store_traits;

    // using vector_type_t = typename Traits::vector_type_t;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // TODO: use structured bindings (to be captured later) when compiling as C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // read from bottom tensor
            get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                smem, bottom_tensor_thread_coord);

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

                m0_inc_with_memory(size_per_issue);
            }
        });
    });
}
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                          bool_constant<oob_conditional_check> = {}) const
{
    using Traits = load_store_traits;

    // using vector_type_t = typename Traits::vector_type_t;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structured bindings (to be captured later) when compiling as C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from distributed tensor
            // vector_type_t vec;
            vector_t vec_value;

            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();

            // write into bottom tensor
            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}

CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
{
    using Traits = load_store_traits;

    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr                    = TileDstr{};
    static constexpr bool oob_conditional_check = true;

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structured bindings (to be captured later) when compiling as C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view()
                .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                    bottom_tensor_thread_coord, vec_value);

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}
// move the thread's bottom tensor coordinate
// [x0', x1', ...] ==> [offset]
// also moves the window origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
    window_origin_ += step;

    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                               pre_computed_coords_(iCoord)(I1),
                               step);
    });
}

// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;

WindowLengths window_lengths_;

// origin ([x0', x1', ...]) of the window on the bottom tensor
BottomTensorIndex window_origin_;

// tile tensor distribution, which contains:
//   1. adaptor for the window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
//   2. thread descriptor for the thread tensor in registers: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;

// this contains:
//   - the per-thread coordinate for the window adaptor
//   - the per-thread coordinate for the bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};

// TODO: use strategy
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin,
                 const StaticTileDistribution_& tile_distribution,
                 number<NumCoord> = {})
{
    return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                remove_cvref_t<WindowLengths_>,
                                                remove_cvref_t<StaticTileDistribution_>,
                                                NumCoord>{
        tensor_view, window_lengths, origin, tile_distribution};
}
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_distribution<TensorView_,
                                         WindowLengths_,
                                         StaticTileDistribution_,
                                         NumCoord>& window,
    const typename tile_window_with_static_distribution<TensorView_,
                                                        WindowLengths_,
                                                        StaticTileDistribution_,
                                                        NumCoord>::BottomTensorIndex& step)
{
    window.move(step);
}

template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType         = typename BottomTensorView::DataType;

    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin}
    {
    }

    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    // move the window origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of the window on the bottom tensor
    BottomTensorIndex window_origin_;
};

template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
                  "wrong! lengths should be static");

    return tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
                                           remove_cvref_t<WindowLengths_>>{
        tensor_view, window_lengths, origin};
}

template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
    const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
        step)
{
    window.move(step);
}

} // namespace ck_tile
19 include/ck_tile/core/utility/bit_cast.hpp Normal file
@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
    static_assert(__has_builtin(__builtin_bit_cast), "");
    static_assert(sizeof(X) == sizeof(Y), "do not support cast between types of different size");

    return __builtin_bit_cast(Y, x);
}

} // namespace ck_tile
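Since `bit_cast` above is a thin wrapper over `__builtin_bit_cast`, its behavior can be checked with a small standalone program; this host-side sketch uses plain `constexpr` in place of `CK_TILE_HOST_DEVICE` and assumes a compiler (clang/gcc) that provides the builtin, as the header's `static_assert` requires:

```cpp
#include <cstdint>

// Standalone re-implementation of ck_tile::bit_cast for illustration.
template <typename Y, typename X>
constexpr Y bit_cast(const X& x)
{
    static_assert(sizeof(X) == sizeof(Y), "do not support cast between types of different size");
    return __builtin_bit_cast(Y, x);
}

// Example: reinterpret the bits of a float as an integer without undefined
// behavior. IEEE-754 single precision: 1.0f has the bit pattern 0x3f800000.
static_assert(bit_cast<uint32_t>(1.0f) == 0x3f800000u, "bit pattern of 1.0f");
// Round-tripping restores the original value exactly.
static_assert(bit_cast<float>(bit_cast<uint32_t>(3.5f)) == 3.5f, "round trip");
```

Unlike a `reinterpret_cast` through pointers, this form is valid in constant expressions, which is why the kernel code above can use it on `constexpr` vector values.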
208 include/ck_tile/core/utility/functional.hpp Normal file
@@ -0,0 +1,208 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include <stdint.h>
#include <utility>

namespace ck_tile {

namespace detail {

struct swallow
{
    template <typename... Ts>
    CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
    {
    }
};

template <class>
struct static_for_impl;

template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        swallow{(f(number<Is>{}), 0)...};
    }
};

} // namespace detail

// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    CK_TILE_HOST_DEVICE constexpr static_for()
    {
        static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                      "wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                      "wrong! should satisfy (Increment > 0 && NBegin <= NEnd) || (Increment < 0 "
                      "&& NBegin >= NEnd)");
    }

    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};

struct identity
{
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
    {
        return std::forward<T>(arg);
    }
};

namespace detail {

// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_ford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
    }

    // F signature: F(sequence<...>)
    // CurrentOrderedId: sequence<...>
    template <class F, class CurrentOrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
    {
        static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
                f, CurrentOrderedId::push_back(I));
        });
    }
};

template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
    // F signature: F(sequence<...>)
    // OrderedId: sequence<...>
    template <class F, class OrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
    {
        // retrieve the unordered id
        f(OrderedId::reorder_old_to_new(Orders{}));
    }
};

} // namespace detail

// Lengths is a sequence<...>: the length of each dimension of an N-dimensional loop.
// Orders is a sequence<...>: the order of the dimensions in which static_ford will
// loop over each dimension.
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
    CK_TILE_HOST_DEVICE constexpr static_ford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
    }

    // F signature: F(sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
    }
};

namespace detail {

template <typename Indices>
struct unpack_impl;

template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
    template <typename F, typename X>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
    }
};

template <typename Seq0, typename Seq1>
struct unpack2_impl;

// TODO: remove this after properly implementing an unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
    template <typename F, typename X, typename Y>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
                                  std::forward<Y>(y).at(number<Js>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
                                  std::forward<Y>(y).template at<Js>()...);
#endif
    }
};

} // namespace detail

template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
    using X_ = remove_reference_t<X>;
    return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x));
}

// TODO: properly implement an unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
    using X_ = remove_reference_t<X>;
    using Y_ = remove_reference_t<Y>;
    return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
                                typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}

// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
    {
        return std::forward<X>(x);
    }
    else
    {
        return std::forward<Y>(y);
    }
}

} // namespace ck_tile
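The `static_for` above unrolls a compile-time loop by expanding a `sequence<Is...>` pack and invoking the functor once per index, each time with a distinct `number<I>` constant. A standalone host-side sketch of the same pattern, using `std::integral_constant`/`std::integer_sequence` in place of ck_tile's `number<>`/`sequence<>` (names here are illustrative, not part of ck_tile):

```cpp
#include <utility>

// Expand the index pack; a C++17 comma-fold replaces the swallow{} trick.
template <int NBegin, int Increment, typename F, int... Is>
constexpr void static_for_detail(F f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, NBegin + Is * Increment>{}), ...);
}

template <int NBegin, int NEnd, int Increment, typename F>
constexpr void static_for(F f)
{
    static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                  "should satisfy (NEnd - NBegin) % Increment == 0");
    static_for_detail<NBegin, Increment>(
        f, std::make_integer_sequence<int, (NEnd - NBegin) / Increment>{});
}

// Usage: the functor sees each index as a compile-time constant, so the index
// can be used as a template argument or array bound inside the body.
constexpr int sum_evens()
{
    int sum = 0;
    static_for<0, 8, 2>([&](auto I) { sum += I.value; });
    return sum; // 0 + 2 + 4 + 6
}
static_assert(sum_evens() == 12, "unrolled loop over 0, 2, 4, 6");
```

This compile-time index is what lets the tile-window code above call `template at<d>()` with an offset `d` computed per iteration: a runtime loop variable could not appear in a template argument.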
22 include/ck_tile/core/utility/ignore.hpp Normal file
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// https://en.cppreference.com/w/cpp/utility/tuple/ignore

namespace ck_tile {

namespace detail {
struct ignore_t
{
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
};
} // namespace detail

inline constexpr detail::ignore_t ignore;

} // namespace ck_tile
240 include/ck_tile/core/utility/magic_div.hpp Normal file
@@ -0,0 +1,240 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>

namespace ck_tile {

// magic number division
// Caution:
// 1. For uint32_t as dividend: the magic number division implementation used here produces
//    the correct result only if the dividend is within the 31-bit value range.
// 2. For int32_t as dividend: magic number division for an int32_t dividend has not been
//    implemented; the int32_t dividend is bit-wise reinterpreted as uint32_t and the magic
//    number division implementation for uint32_t is used. Therefore, the dividend value
//    needs to be non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with the full 32-bit value range
struct magic_division32_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= INT32_MAX)

        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        }

        uint64_t tmp_u64        = ((1UL << shift_u32) - divisor) << 32;
        uint32_t multiplier_u32 = tmp_u64 / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }

    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = __umulhi(dividend, multiplier);
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for an int32_t dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = __umulhi(dividend_u32, multiplier);
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
};

// magic number division
// This version only works for divisor and dividend within [0, 1 << 16]
struct magic_division16_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= (1U << 16));

        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        }

        uint32_t one            = 1;
        uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    // integral_constant<uint32_t, .>
    template <auto Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }

    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for an int32_t dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
};

// use 32-bit version
using magic_division = magic_division32_bit_range;

struct mdiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
    {
        divisor  = divisor_;
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }

    CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};

struct mdiv2
{
    // 1 dword -> 2 dword storage; the divisor needs to be provided at runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};

} // namespace ck_tile
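A host-only sketch of what `calculate_magic_numbers` / `do_magic_division` compute, in plain C++. The names `calc`/`magic_div` are illustrative, and `__umulhi` is emulated with a 64-bit multiply, exactly as the `CK_TILE_HOST` overload does. As the comments above warn, this is only valid for dividends within the 31-bit range:

```cpp
#include <cassert>
#include <cstdint>

struct magic
{
    uint32_t multiplier;
    uint32_t shift;
};

// same recipe as magic_division32_bit_range::calculate_magic_numbers:
// shift = ceil(log2(divisor)), multiplier = floor(((2^shift - divisor) << 32) / divisor) + 1
magic calc(uint32_t divisor)
{
    uint32_t shift = 0;
    while((1u << shift) < divisor)
        shift++;
    uint64_t tmp        = ((1ull << shift) - divisor) << 32;
    uint32_t multiplier = static_cast<uint32_t>(tmp / divisor + 1);
    return {multiplier, shift};
}

// host path of do_magic_division: umulhi emulated with a 64-bit product
uint32_t magic_div(uint32_t dividend, magic m)
{
    uint32_t tmp = static_cast<uint32_t>((static_cast<uint64_t>(dividend) * m.multiplier) >> 32);
    return (tmp + dividend) >> m.shift;
}
```

This trades the divisor-dependent `/` and `%` for a multiply-high, an add, and a shift, which is the point of `mdiv`/`mdiv2`: pay the setup cost once on the host and reuse the magic numbers across the whole grid.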
include/ck_tile/core/utility/random.hpp (new file, 58 lines)
@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>

namespace ck_tile {

// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};

// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
    {
        uint32_t x         = *(reinterpret_cast<uint32_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits ^= x >> 16;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;
        // NOTE: if id is a 64-bit value, only the lower 32 bits are used here, so the
        // same id can effectively be reused for multiple elements when the id is very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};

// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
    {
        uint16_t x         = *(reinterpret_cast<uint16_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits          = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;
        // NOTE: if id is a 64-bit value, only the lower 32 bits are used here, so the
        // same id can effectively be reused for multiple elements when the id is very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};

} // namespace ck_tile
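The generator hashes the value's bit pattern together with the element id and the seed. A host-side copy of the fp32 path (the function name `prand_fp32` is illustrative, and `memcpy` replaces the `reinterpret_cast` for strict-aliasing safety). Because the seed enters through a final xor, two calls differing only in seed produce outputs that differ by exactly `seed1 ^ seed2`:

```cpp
#include <cstdint>
#include <cstring>

// mirrors prand_generator_t<float, seed_>::operator()
uint32_t prand_fp32(int id, float val, uint32_t seed)
{
    uint32_t x;
    std::memcpy(&x, &val, sizeof(x)); // bit pattern of the float
    uint32_t drop_bits = x & 0xFFFFu;
    drop_bits ^= x >> 16;
    drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
    drop_bits *= 0x7000149;
    return drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed;
}
```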
include/ck_tile/core/utility/to_sequence.hpp (new file, 73 lines)
@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use a C++20 non-type template parameter of class type to implement this

#if 1
// clang happens to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n)                                                                    \
    _Pragma("clang diagnostic push") _Pragma(                                                \
        "clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
        ck_tile::sequence<IDX_IDX_...>)                                                      \
    {                                                                                        \
        return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{};                    \
    }                                                                                        \
    (ck_tile::make_index_sequence<n>{});                                                     \
    _Pragma("clang diagnostic pop")

#else
// Macro function
// convert a constexpr array to a sequence; both a and n need to be constexpr
// (neither can be an rvalue like a literal 2)
#define TO_SEQUENCE(a, n)                                                                 \
    [a, n] {                                                                              \
        static_assert(a.size() >= n, "wrong! out of bound");                              \
        static_assert(n <= 10, "not implemented");                                        \
        if constexpr(n == 0)                                                              \
        {                                                                                 \
            return ck_tile::sequence<>{};                                                 \
        }                                                                                 \
        else if constexpr(n == 1)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0]>{};                                             \
        }                                                                                 \
        else if constexpr(n == 2)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1]>{};                                       \
        }                                                                                 \
        else if constexpr(n == 3)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2]>{};                                 \
        }                                                                                 \
        else if constexpr(n == 4)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3]>{};                           \
        }                                                                                 \
        else if constexpr(n == 5)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{};                     \
        }                                                                                 \
        else if constexpr(n == 6)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{};               \
        }                                                                                 \
        else if constexpr(n == 7)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{};         \
        }                                                                                 \
        else if constexpr(n == 8)                                                         \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{};   \
        }                                                                                 \
        else if constexpr(n == 9)                                                        \
        {                                                                                 \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7],      \
                                     a[8]>{};                                             \
        }                                                                                 \
        else if constexpr(n == 10)                                                        \
        {                                                                                 \
            return ck_tile::                                                              \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{};   \
        }                                                                                 \
    }()
#endif
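The clang branch above relies on a C++20 template lambda; the same array-to-sequence trick can be written as a plain template function over `std::array` and `std::integer_sequence` (a standalone sketch with illustrative names, not the ck_tile API):

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

constexpr std::array<int, 3> kArr{2, 4, 8};

// expand the array elements into a compile-time integer_sequence,
// analogous to TO_SEQUENCE(kArr, 3)
template <std::size_t... Is>
constexpr auto to_seq(std::index_sequence<Is...>)
{
    return std::integer_sequence<int, kArr[Is]...>{};
}

using Seq = decltype(to_seq(std::make_index_sequence<3>{}));
static_assert(std::is_same<Seq, std::integer_sequence<int, 2, 4, 8>>::value, "wrong!");
```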
include/ck_tile/core/utility/transpose_vectors.hpp (new file, 125 lines)
@@ -0,0 +1,125 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

// S: scalar type (or it can be a non-scalar type)
// NX: # of vectors before the transpose
// NY: # of vectors after the transpose
// we have [NX, NY] worth of S data to be transposed into [NY, NX] worth of S data
template <typename S_, index_t NX, index_t NY>
struct transpose_vectors
{
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S = remove_cvref_t<S_>;

    using VX = array<S, s_per_x>;
    using VY = array<S, s_per_y>;

    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
                                   thread_buffer<VY, NY>& vy_tuple)
    {
        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};
        constexpr auto I3 = number<3>{};
        constexpr auto I4 = number<4>{};

        if constexpr(sizeof(S) == 2)
        {
            static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

            using S2 = array<S, 2>;

            // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
            static_for<0, NY, 2>{}([&](auto iy) {
                static_for<0, NX, 2>{}([&](auto ix) {
                    // two 16bitx2 values from vx_tuple to be transposed
                    const int32_t x_s2_0 =
                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
                    const int32_t x_s2_1 =
                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);

                    constexpr int32_t m0 = 0x05040100;
                    constexpr int32_t m1 = 0x07060302;

                    // transpose 2x2 16bit
                    // ex: v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
                    //     byte index of {src0:src1}: 7 6 5 4 (src0), 3 2 1 0 (src1)
                    // the index is reversed because of little endianness (least significant byte first)
                    const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
                    const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);

                    // two 16bitx2 values after the transpose
                    vy_tuple(iy).template get_as<S2>()(ix / I2)      = bit_cast<S2>(y_s2_0);
                    vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
                });
            });
        }
        else if constexpr(sizeof(S) == 1)
        {
            static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");

            using S4 = array<S, 4>;

            // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
            static_for<0, NY, 4>{}([&](auto iy) {
                static_for<0, NX, 4>{}([&](auto ix) {
                    // four int8x4 values from vx_tuple
                    const int32_t x_s4_0 =
                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_1 =
                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_2 =
                        bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_3 =
                        bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);

                    // transpose
                    int32_t t_s4_0, t_s4_1;
                    int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;

                    constexpr int32_t m0 = 0x05010400;
                    constexpr int32_t m1 = 0x05040100;
                    constexpr int32_t m2 = 0x07060302;
                    constexpr int32_t m3 = 0x07030602;

                    // ex: v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
                    // the index is reversed because of little endianness (least significant byte first)
                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
                    y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                    y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
                    y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                    y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

                    // four int8x4 values into vy_tuple
                    vy_tuple(iy).template get_as<S4>()(ix / I4)      = bit_cast<S4>(y_s4_0);
                    vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
                    vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
                    vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
                });
            });
        }
        else
        {
            static_assert(false, "not implemented");
        }
    }
};

} // namespace ck_tile
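`v_perm_b32` picks each output byte by index from the 8-byte pair {src0:src1}, where src1 supplies bytes 0-3 and src0 bytes 4-7. A software model of the builtin (an assumption for illustration, not the ISA definition) makes the masks used above checkable on the CPU:

```cpp
#include <cstdint>

// software model of __builtin_amdgcn_perm(a, b, sel):
// byte i of the result is byte sel[i] of the 64-bit value {a:b}
inline uint32_t perm_b32(uint32_t a, uint32_t b, uint32_t sel)
{
    uint64_t src = (static_cast<uint64_t>(a) << 32) | b;
    uint32_t out = 0;
    for(int i = 0; i < 4; ++i)
    {
        uint32_t idx = (sel >> (8 * i)) & 0xFFu;
        out |= static_cast<uint32_t>((src >> (8 * idx)) & 0xFFu) << (8 * i);
    }
    return out;
}
```

With x0 = [a1|a0] and x1 = [b1|b0] holding a 2x2 tile of 16-bit values (low half first), mask 0x05040100 gathers the two low halves into [b0|a0] and 0x07060302 the two high halves into [b1|a1], which is exactly the 2x2 transpose performed above.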
include/ck_tile/core/utility/type_traits.hpp (new file, 95 lines)
@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include <type_traits>
#include <stdint.h>

namespace ck_tile {

// remove_cvref_t and friends
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;

namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Default;
};

template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};
} // namespace detail

struct nonesuch
{
    ~nonesuch()               = delete;
    nonesuch(nonesuch const&) = delete;
    void operator=(nonesuch const&) = delete;
};

template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;

namespace impl {

template <typename T>
using has_is_static = decltype(T::is_static());

template <typename T>
struct is_static_impl
{
    static constexpr bool value = []() {
        if constexpr(is_detected<has_is_static, T>{})
            return T::is_static();
        else
            return std::is_arithmetic<T>::value;
    }();
};
} // namespace impl

template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;

template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;

// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: when evaluating an rvalue, e.g. a const integer,
// this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?

// FIXME: do we need this anymore?
template <
    typename PY,
    typename PX,
    typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
    return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}

} // namespace ck_tile
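`is_detected` implements the detection idiom, in the style of the Library Fundamentals TS `std::experimental::is_detected`. A self-contained copy showing how the `has_is_static` probe drives `is_static` (the types `A`/`B` are made-up probes for illustration):

```cpp
#include <type_traits>

template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
};

// chosen via SFINAE whenever Op<Args...> names a valid type
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
};

template <template <class...> class Op, class... Args>
using is_detected = typename detector<void, void, Op, Args...>::value_t;

// probe: does T have a static member function is_static()?
template <class T>
using has_is_static = decltype(T::is_static());

struct A
{
    static constexpr bool is_static() { return true; }
};
struct B
{
};

static_assert(is_detected<has_is_static, A>::value, "A has is_static()");
static_assert(!is_detected<has_is_static, B>::value, "B does not");
```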
include/ck_tile/core/utility/unary_element_function.hpp (new file, 67 lines)
@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename F, typename... Fs>
struct composes : private composes<F>
{
    template <typename FirstArg, typename... RestArgs>
    CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
        : composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
    {
    }

    template <typename Arg>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
    }

    private:
    composes<Fs...> inner_;
};

template <typename F>
struct composes<F>
{
    static_assert(!std::is_reference_v<F>);

    template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
    CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
    {
    }

    template <typename Arg,
              typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return f_(std::forward<Arg>(arg));
    }

    private:
    F f_;
};

/// FIXME: create a macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;

template <typename To>
struct saturates
{
    template <typename From>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
        -> std::enable_if_t<std::is_arithmetic_v<From>, From>
    {
        return clamp(from,
                     type_convert<From>(numeric<To>::lowest()),
                     type_convert<From>(numeric<To>::max()));
    }
};

} // namespace ck_tile
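`composes(f, g, ...)` builds a functor that applies the innermost callable first: `composes(f, g)(x)` evaluates `f(g(x))`. A standalone sketch without the ck_tile macros and with simplified constraints (the name `compose_t` is illustrative):

```cpp
#include <type_traits>
#include <utility>

template <typename F, typename... Fs>
struct compose_t : private compose_t<F>
{
    template <typename FirstArg, typename... RestArgs>
    constexpr explicit compose_t(FirstArg&& first, RestArgs&&... rest)
        : compose_t<F>(std::forward<FirstArg>(first)), inner_(std::forward<RestArgs>(rest)...)
    {
    }

    template <typename Arg>
    constexpr auto operator()(Arg&& arg) const
    {
        // apply the inner chain first, then the outermost F
        return static_cast<const compose_t<F>&>(*this)(inner_(std::forward<Arg>(arg)));
    }

private:
    compose_t<Fs...> inner_;
};

// base case: a single wrapped callable
template <typename F>
struct compose_t<F>
{
    template <typename Arg>
    constexpr explicit compose_t(Arg&& arg) : f_(std::forward<Arg>(arg))
    {
    }

    template <typename Arg>
    constexpr auto operator()(Arg&& arg) const
    {
        return f_(std::forward<Arg>(arg));
    }

private:
    F f_;
};

template <typename... Ts>
compose_t(Ts&&...) -> compose_t<std::decay_t<Ts>...>;
```

The recursive inheritance mirrors the original: the partial specialization `compose_t<F>` terminates the chain, and the deduction guide strips references so stored callables are held by value.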