introducing ck_tile! (#1216)

* enable gfx940

* switch between intrinsic mfma routines on mi100/200 and mi300

* fix mfma_int8 on MI300

* disable 2 int8 examples on MI300

* Update cmake-ck-dev.sh

* restore gitignore file

* modify Jenkinsfile to the internal repo

* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx

Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* initial enablement of gfx950

* fix clang format

* disable examples 31 and 41 int8 on gfx950

* add code

* fix build wip

* fix xx

* now can build

* naming

* minor fix

* wip fix

* fix macro for exp2; fix warpgemm a/b in transposedC

* unify as tuple_array

* Update the required Python version to 3.9

* Update executable name in test scripts

* re-structure tuple/array to avoid spill

* Merge function templates

* Fix format

* Add constraint to array<> ctor

* Re-use function

* Some minor changes

* remove wrong code in store_raw()

* fix compile issue in transpose

* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'

* let more integral_constant->constant, and formating

* make sure thread_buffer can be tuple/array

* temp fix buffer_store spill

* not using custom data type by default, now we can have ISA-level same code as opt_padding

* fix compile error, fp8 not ready now

* fix fp8 duplicated move/shift/and/or problem

* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode

* fix scratch in fp8 kernel

* update some readme

* fix merge from upstream

* sync with upstream

* sync upstream again

* sync 22

* remove unused

* fix clang-format

* update README of ck_tile example

* fix several issue

* let python version to be 3.8 as minimal

* remove ck_tile example from default cmake target like all/install/check

* remove mistake

* 1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg

* fix some bug in group-mode masking and codegen. update README

* F8 quantization for FMHA forward (#1224)

* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline

* Add element function to fmha api

* Adjust P elementwise function

* Fix bug of elementwise op, our elementwise op is not inout

* Add some elementwise op, prepare to quantization

* Let generate.py can generate different elementwise function

* To prevent compiler issue, remove the elementwise function we have not used.

* Remove f8 pipeline, we should share the same pipeline even in f8

* Remove remove_cvref_t

* Avoid warning

* Fix wrong fp8 QK/KV block gemm setting

* Check fp8 rounding error in check_err()

* Set fp8 rounding error for check_err()

* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode

* 1. codgen the f8 api and kernel
2. f8 host code

* prevent warning in filter mode

* Remove not-in-use elementwise function kargs

* Remove more not-in-use elementwise function kargs

* Small refinements in C++ source files

* Use conditional_t<> to simplify code

* Support heterogeneous argument for binary function types

* Re-use already-existing scales<> functor template

* Fix wrong value produced by saturating

* Generalize the composes<> template

* Unify saturates<> implementation

* Fix type errors in composes<>

* Extend less_equal<>

* Reuse the existing template less_equal<> in check_err()

* Add equal<float> & equal<double>

* Rename check_err() parameter

* Rename check_err() parameter

* Add FIXME comment for adding new macro in future

* Remove unnecessary cast to void

* Eliminate duplicated code

* Avoid dividing api pool into more than 2 groups

* Use more clear variable names

* Use affirmative condition in if stmt

* Remove blank lines

* Donot perfect forwarding in composes<>

* To fix compile error, revert generate.py back to 4439cc107d

* Fix bug of p element function

* Add compute element op to host softmax

* Remove element function in api interface

* Extract user parameter

* Rename pscale and oscale variable

* rename f8 to fp8

* rename more f8 to fp8

* Add pipeline::operator() without element_functor

* 1. Remove deprecated pipeline enum
2. Refine host code parameter

* Use quantization range as input

* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s
3.Add init description

* Refine description

* prevent early return

* unify _squant kernel name in cpp, update README

* Adjust the default range.

* Refine error message and bias range

* Add fp8 benchmark and smoke test

* fix fp8 swizzle_factor=4 case

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
carlushuang
2024-04-16 08:27:12 +08:00
committed by GitHub
parent dd34ab6e64
commit db376dd8a4
141 changed files with 30623 additions and 2 deletions

View File

@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
} // namespace ck_tile

View File

@@ -0,0 +1,208 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
struct swallow
{
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
{
}
};
template <class>
struct static_for_impl;
template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
swallow{(f(number<Is>{}), 0)...};
}
};
} // namespace detail
// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
CK_TILE_HOST_DEVICE constexpr static_for()
{
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
"NBegin >= NEnd)");
}
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
struct identity
{
template <typename T>
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
{
return std::forward<T>(arg);
}
};
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
}
// F signature: F(sequence<...>)
// CurrentOrderedId: sequence<...>
template <class F, class CurrentOrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
{
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
f, CurrentOrderedId::push_back(I));
});
}
};
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
// F signature: F(sequence<...>)
// OrderedId: sequence<...>
template <class F, class OrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
{
// retrive unordered Id
f(OrderedId::reorder_old_to_new(Orders{}));
}
};
} // namespace detail
// Lengths is sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
CK_TILE_HOST_DEVICE constexpr static_ford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
}
// F signature: F(sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
}
};
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
}
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
#if 0
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
std::forward<Y>(y).at(number<Js>{})...);
#else
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
std::forward<Y>(y).template at<Js>()...);
#endif
}
};
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
using X_ = remove_reference_t<X>;
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
}
else
{
return std::forward<Y>(y);
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
namespace ck_tile {
namespace detail {
struct ignore_t
{
template <typename T>
constexpr void operator=(T&&) const noexcept
{
}
};
} // namespace detail
inline constexpr detail::ignore_t ignore;
} // namespace ck_tile

View File

@@ -0,0 +1,240 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
namespace ck_tile {
// magic number division
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
// division implementation for uint32_t is then used. Therefore, dividend value need to be
// non-negative.
// TODO:
// 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range
struct magic_division32_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= INT32_MAX)
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32;
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
};
// magic number division
// This version on works for divisor and dividended between [0, 1 << 16]
struct magic_division16_bit_range
{
// uint32_t
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
{
// WARNING: magic division is only valid for division inside this range.
// assert(divisor >= 1 && divisor <= (1U << 16));
uint32_t shift_u32 = 0;
while((1U << shift_u32) < divisor)
{
shift_u32++;
};
uint32_t one = 1;
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);
}
// integral_constant<uint32_t, .>
template <auto Divisor>
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (dividend * multiplier) >> 16;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
return (tmp + dividend_u32) >> shift;
}
};
// use 32bit version
using magic_division = magic_division32_bit_range;
struct mdiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
{
divisor = divisor_;
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
struct mdiv2
{
// 1 dword -> 2 dword storage, divisor need compute from runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
{
auto tmp = magic_division::calculate_magic_numbers(divisor_);
multiplier = tmp[number<0>{}];
shift = tmp[number<1>{}];
}
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
{
return magic_division::do_magic_division(dividend_, multiplier, shift);
}
CK_TILE_HOST_DEVICE void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
namespace ck_tile {
// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};
// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
// So, it can have an effect of using same id for multiple elements when the id is
// very large!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
// So, it can have an effect of using same id for multiple elements when the id is
// very large!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use c++20 nontype template with struct to implement this
#if 1
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
#define TO_SEQUENCE(a, n) \
[a, n] { \
static_assert(a.size() >= n, "wrong! out of bound"); \
static_assert(n <= 10, "not implemented"); \
if constexpr(n == 0) \
{ \
return ck_tile::sequence<>{}; \
} \
else if constexpr(n == 1) \
{ \
return ck_tile::sequence<a[0]>{}; \
} \
else if constexpr(n == 2) \
{ \
return ck_tile::sequence<a[0], a[1]>{}; \
} \
else if constexpr(n == 3) \
{ \
return ck_tile::sequence<a[0], a[1], a[2]>{}; \
} \
else if constexpr(n == 4) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3]>{}; \
} \
else if constexpr(n == 5) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{}; \
} \
else if constexpr(n == 6) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{}; \
} \
else if constexpr(n == 7) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{}; \
} \
else if constexpr(n == 8) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{}; \
} \
else if constexpr(n == 9) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
} \
else if constexpr(n == 10) \
{ \
return ck_tile:: \
sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{}; \
} \
}()
#endif

View File

@@ -0,0 +1,125 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// S: scalar type (or it can be non-scalar type)
// NX: # of vector before transpose
// NY: # of vector after transpose
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
template <typename S_, index_t NX, index_t NY>
struct transpose_vectors
{
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = remove_cvref_t<S_>;
using VX = array<S, s_per_x>;
using VY = array<S, s_per_y>;
CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
thread_buffer<VY, NY>& vy_tuple)
{
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
constexpr auto I3 = number<3>{};
constexpr auto I4 = number<4>{};
if constexpr(sizeof(S) == 2)
{
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
using S2 = array<S, 2>; // typename array<S, 2>::type;
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// 2 16bitx2 data from vx_tuple to be transposed
const int32_t x_s2_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
const int32_t x_s2_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
// transpose 2x2 16bit
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
// 2 16bitx2 data after transposed
vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
});
});
}
else if constexpr(sizeof(S) == 1)
{
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
using S4 = array<S, 4>; // typename array<S, 4>::type;
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 4>{}([&](auto iy) {
static_for<0, NX, 4>{}([&](auto ix) {
// 4 int8x4 data from vx_tuple
const int32_t x_s4_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
const int32_t x_s4_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
const int32_t x_s4_2 =
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
const int32_t x_s4_3 =
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
// transpose
int32_t t_s4_0, t_s4_1;
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
constexpr int32_t m0 = 0x05010400;
constexpr int32_t m1 = 0x05040100;
constexpr int32_t m2 = 0x07060302;
constexpr int32_t m3 = 0x07030602;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
// 4 int8x4 data from vy_tuple
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
});
});
}
else
{
static_assert(false, "not implemented");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace detail
struct nonesuch
{
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
namespace impl {
template <typename T>
using has_is_static = decltype(T::is_static());
template <typename T>
struct is_static_impl
{
static constexpr bool value = []() {
if constexpr(is_detected<has_is_static, T>{})
return T::is_static();
else
return std::is_arithmetic<T>::value;
}();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
template <
typename PY,
typename PX,
typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}
} // namespace ck_tile

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename F, typename... Fs>
struct composes : private composes<F>
{
template <typename FirstArg, typename... RestArgs>
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
: composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
{
}
template <typename Arg>
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
{
return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
}
private:
composes<Fs...> inner_;
};
template <typename F>
struct composes<F>
{
static_assert(!std::is_reference_v<F>);
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
{
}
template <typename Arg,
typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
{
return f_(std::forward<Arg>(arg));
}
private:
F f_;
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
template <typename To>
struct saturates
{
template <typename From>
CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
-> std::enable_if_t<std::is_arithmetic_v<From>, From>
{
return clamp(from,
type_convert<From>(numeric<To>::lowest()),
type_convert<From>(numeric<To>::max()));
}
};
} // namespace ck_tile