This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Reinterpret the object representation of `x` as a value of type `Y`
// (constexpr-friendly type punning; sizes must match exactly).
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
    // Relies on the compiler builtin so the cast works in constant expressions.
    static_assert(__has_builtin(__builtin_bit_cast), "");
    return __builtin_bit_cast(Y, x);
}
} // namespace ck_tile

View File

@@ -0,0 +1,161 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <stdio.h>
#include <tuple>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
// Compile-time debugging helper: instantiating CK_PRINT<vals...>() makes the
// compiler print the non-type template arguments inside its deprecation warning.
template <auto... val>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
// Same trick for types: the deprecation warning echoes the template arguments.
template <typename... type>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
// Forward declarations; full definitions live in the tensor/buffer headers.
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
template <typename T_, index_t N_>
struct thread_buffer;
// Usage example: CK_PRINTF<float>{}(tensor);
// Primary template. ConvertTo selects the type elements are converted to before
// printing (void = keep the buffer's element type); FMT/PREFIX/SUFFIX are
// compile-time str_literal strings customizing the printf output.
template <typename ConvertTo = void,
          typename FMT = str_literal<>,
          typename PREFIX = str_literal<>,
          typename SUFFIX = str_literal<>>
struct CK_PRINTF;
// Specialization that performs the printing; the format/prefix/suffix literals
// are unpacked into char packs so the whole printf format string can be
// assembled at compile time.
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
struct CK_PRINTF<ConvertTo,
                 str_literal<FMTChars...>,
                 str_literal<PREFIXChars...>,
                 str_literal<SUFFIXChars...>>
{
    // Pick a printf conversion spec and a printable storage type for T.
    // Known arithmetic types print as values; any other 1/2/4-byte type is
    // printed as the hex of its object representation (via bit_cast below).
    template <typename T>
    CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
    {
        if constexpr(std::is_same_v<T, float>)
            return std::make_tuple(make_str_literal("%8.3f"), T{});
        else if constexpr(std::is_same_v<T, int>)
            return std::make_tuple(make_str_literal("%5d"), T{});
        else if constexpr(std::is_same_v<T, unsigned int>)
            return std::make_tuple(make_str_literal("%5u"), T{});
        else if constexpr(sizeof(T) == 1)
            return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
        else if constexpr(sizeof(T) == 2)
            return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
        else if constexpr(sizeof(T) == 4)
            return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
        else
            static_assert(false, "Unsupported type");
    }
    // Format literal chosen above for T.
    template <typename T>
    using default_format_t =
        std::remove_reference_t<decltype(std::get<0>(default_format_and_type<T>()))>;
    // Storage type the value is bit_cast to before being passed to printf.
    template <typename T>
    using default_type_t =
        std::remove_reference_t<decltype(std::get<1>(default_format_and_type<T>()))>;
    // Line prefix: thread id + element count, optionally followed by PREFIX.
    CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
    {
        constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
        if constexpr(sizeof...(PREFIXChars) == 0)
            return fmt_tid;
        else
            return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
    }
    // Line suffix: optional SUFFIX, always ending with a newline.
    CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
    {
        constexpr auto lf = make_str_literal("\n");
        if constexpr(sizeof...(SUFFIXChars) == 0)
            return lf;
        else
            return str_literal<SUFFIXChars...>{} + lf;
    }
    // Print the whole buffer with ONE printf call: the per-element format is
    // duplicated N times (space separated) and each element is converted to Y
    // and bit_cast to the matching printable type.
    template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
    CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
                                  std::integer_sequence<index_t, Is...>,
                                  Args&&... args) const
    {
        // User-supplied FMT wins; otherwise use the default for Y.
        using FMT1 = std::
            conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
        constexpr auto fmt_v      = FMT1::template duplicate_n<N>(make_str_literal(" "));
        constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
        printf(fmt_wrap_v.data,
               get_thread_id(),
               N,
               args...,
               bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
#pragma clang diagnostic pop
    }
    // Entry point for a raw thread_buffer; extra args feed PREFIX's conversions.
    template <typename T, index_t N, typename... Args>
    CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
    {
        // void means "no conversion": print with the buffer's own element type.
        using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
        impl<T, N, ConvertTo_>(
            buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
    }
    // Entry point for a distributed tensor: print its thread-local buffer.
    template <typename... TS, typename... Args>
    CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
                                        Args&&... args) const
    {
        return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
    }
};
// Invoke `print` with the argument, but only from threads of the first warp.
template <typename T>
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
{
    if(get_thread_id() >= get_warp_size())
        return;
    print(std::forward<T>(x));
}
// CK_PRINTF variant that prints only from threads of the first warp,
// avoiding duplicated output when every warp holds the same data.
template <typename... Ts>
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
{
    using base_t = CK_PRINTF<Ts...>;
    template <typename T, typename... Args>
    CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
    {
        const bool in_first_warp = get_thread_id() < get_warp_size();
        if(in_first_warp)
        {
            base_t::operator()(buf, std::forward<Args>(args)...);
        }
    }
};
/*
* RAII struct which inserts start/end markers into the generated assembly.
*
* Usage:
* - Create an `AsmScopeMarker` object at the beginning of a scope or code block.
* - Its constructor will emit a "CK_ASM_SCOPE_START" marker into the assembly.
* - When the object goes out of scope (end of block, return, exception, etc.),
* the destructor will emit a "CK_ASM_SCOPE_END" marker.
*
* Example:
* {
* [[maybe_unused]] AsmScopeMarker marker; // Emits CK_ASM_SCOPE_START
* // ... code you want to delimit in assembly ...
* } // marker goes out of scope → Emits CK_ASM_SCOPE_END
*
*/
struct AsmScopeMarker
{
    // in some future version of clang we might be able to use string_view to customize
    // The markers are emitted as assembly comments, so they are searchable in
    // the generated .s output but have no runtime effect.
    CK_TILE_HOST_DEVICE AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_START"); }
    CK_TILE_HOST_DEVICE ~AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_END"); }
};
} // namespace ck_tile

View File

@@ -0,0 +1,220 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
namespace ck_tile {
// Stream all arguments into one message and emit it on stderr with the
// "[CK_TILE_ERROR] " prefix (flushed via endl).
template <typename... Args>
void CK_TILE_ERROR(Args&&... args) noexcept
{
    std::ostringstream message;
    (message << ... << args);
    const std::string text = message.str();
    std::cerr << "[CK_TILE_ERROR] " << text << std::endl;
}
// Stream all arguments into one message and emit it on stdout with the
// "[CK_TILE_INFO] " prefix (flushed via endl).
template <typename... Args>
void CK_TILE_INFO(Args&&... args) noexcept
{
    std::ostringstream message;
    (message << ... << args);
    const std::string text = message.str();
    std::cout << "[CK_TILE_INFO] " << text << std::endl;
}
namespace internal {
// True iff `str` compares equal to one of the C-string names in the array.
template <size_t N>
bool is_any_of(const char* const (&names)[N], const std::string& str)
{
    const auto matches = [&str](const char* candidate) { return str == candidate; };
    return std::any_of(std::begin(names), std::end(names), matches);
}
// Primary template: intentionally empty so that using an unsupported value
// type with the env-var machinery fails to compile.
template <typename T>
struct ParseEnvVal
{
};
// Parser for boolean env values. Accepts (case-insensitively) the names in
// enabled_names / disabled_names; anything else throws std::runtime_error.
template <>
struct ParseEnvVal<bool>
{
    static bool parse_env_var_value(const char* vp)
    {
        std::string value_env_str{vp};
        for(auto& c : value_env_str)
        {
            // Cast to unsigned char: feeding a plain (possibly negative) char
            // to std::isalpha/std::tolower is undefined behavior.
            if(std::isalpha(static_cast<unsigned char>(c)) != 0)
            {
                c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
            }
        }
        if(is_any_of(enabled_names, value_env_str))
        {
            return true;
        }
        if(is_any_of(disabled_names, value_env_str))
        {
            return false;
        }
        throw std::runtime_error("Invalid value for env variable");
    }

    private:
    static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"};
    static constexpr const char* disabled_names[] = {
        "disable", "disabled", "0", "no", "off", "false"};
};
// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
template <>
struct ParseEnvVal<uint64_t>
{
    // Base 0 lets strtoull auto-detect the 0x/0 prefix.
    static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
};
// Parser for string env values: the raw text is used verbatim.
template <>
struct ParseEnvVal<std::string>
{
    static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
};
// Cached environment variable of type T. The process environment is read once
// at construction; all later queries and updates go through the cached value.
template <typename T>
struct EnvVar
{
    private:
    T value{};
    bool is_unset = true;

    public:
    const T& GetValue() const { return value; }
    // True while no value has been read from the environment or set explicitly.
    bool IsUnset() const { return is_unset; }
    void Unset() { is_unset = true; }
    void UpdateValue(const T& val)
    {
        value    = val;
        is_unset = false;
    }
    explicit EnvVar(const char* const name, const T& def_val)
    {
        // NOLINTNEXTLINE (concurrency-mt-unsafe)
        const char* env_text = std::getenv(name);
        if(env_text == nullptr)
        {
            // Not present in the environment: keep the default, stay "unset".
            value = def_val;
        }
        else
        {
            value    = ParseEnvVal<T>::parse_env_var_value(env_text);
            is_unset = false;
        }
    }
};
} // end namespace internal
// Static inside function hides the variable and provides
// thread-safety/locking
// Used in global namespace
// Declares a tag type ck_tile::env::<name> whose Ref() lazily constructs the
// cached EnvVar (Meyers singleton) on first use.
#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val)                            \
    namespace ck_tile::env {                                                        \
    struct name                                                                     \
    {                                                                               \
        static_assert(std::is_same_v<name, ::ck_tile::env::name>,                   \
                      "CK_TILE_DECLARE_ENV* must be used in the global namespace"); \
        using value_type = type;                                                    \
        static ck_tile::internal::EnvVar<type>& Ref()                               \
        {                                                                           \
            static ck_tile::internal::EnvVar<type> var{#name, default_val};         \
            return var;                                                             \
        }                                                                           \
    };                                                                              \
    }
// Convenience wrappers for the common value types.
#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false)
#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0)
#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "")
// Instantiates the tag so it can be passed to the Env* query helpers below.
#define CK_TILE_ENV(name) \
    ck_tile::env::name {}
// Read the cached value of a string-typed env variable tag.
template <class EnvVar>
inline const std::string& EnvGetString(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
    auto& var = EnvVar::Ref();
    return var.GetValue();
}
// True iff the boolean env variable was explicitly set AND parses as true.
template <class EnvVar>
inline bool EnvIsEnabled(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
    auto& var = EnvVar::Ref();
    if(var.IsUnset())
        return false;
    return var.GetValue();
}
// True iff the boolean env variable was explicitly set AND parses as false.
template <class EnvVar>
inline bool EnvIsDisabled(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
    auto& var = EnvVar::Ref();
    if(var.IsUnset())
        return false;
    return !var.GetValue();
}
// Read the cached value of a uint64_t-typed env variable tag.
template <class EnvVar>
inline uint64_t EnvValue(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
    auto& var = EnvVar::Ref();
    return var.GetValue();
}
// True while the variable has neither been read from the environment nor set.
template <class EnvVar>
inline bool EnvIsUnset(EnvVar)
{
    return EnvVar::Ref().IsUnset();
}
// Mark the cached variable as unset (the cached value itself is retained).
template <class EnvVar>
void EnvUnset(EnvVar)
{
    EnvVar::Ref().Unset();
}
/// Updates the cached value of an environment variable
/// (does not modify the process environment itself).
template <typename EnvVar, typename ValueType>
void UpdateEnvVar(EnvVar, const ValueType& val)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
    EnvVar::Ref().UpdateValue(val);
}
/// Updates the cached value by parsing `val` with the variable's parser.
template <typename EnvVar>
void UpdateEnvVar(EnvVar, const std::string_view& val)
{
    // std::string_view is NOT guaranteed to be NUL-terminated, so passing
    // val.data() straight to the C-string parsers could read past the view.
    // Materialize a std::string to obtain a properly terminated buffer.
    const std::string val_str{val};
    EnvVar::Ref().UpdateValue(
        ck_tile::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(
            val_str.c_str()));
}
} // namespace ck_tile
// environment variable to enable logging:
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#pragma clang diagnostic pop

View File

@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include <stdint.h>
#include <utility>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
namespace ck_tile {
namespace detail {
// Helper whose braced-init construction "swallows" any number of arguments;
// swallow{(expr, 0)...} expands a pack of statements with a guaranteed
// left-to-right evaluation order.
struct swallow
{
    template <typename... Ts>
    CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
    {
    }
};
template <class>
struct static_for_impl;
// Invokes f(number<Is>{}) for every index of the sequence, in order.
template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // (f(...), 0) gives each expression an int value for swallow's ctor.
        swallow{(f(number<Is>{}), 0)...};
    }
};
} // namespace detail
// F signature: F(number<Iter>)
// Compile-time for-loop: invokes f(number<I>{}) for
// I = NBegin, NBegin + Increment, ... up to (excluding) NEnd.
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    CK_TILE_HOST_DEVICE constexpr static_for()
    {
        static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        // Fixed typo in the diagnostic text ("wrongs!" -> "wrong!").
        static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                      "wrong! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
                      "NBegin >= NEnd)");
    }
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};
namespace detail {
// Fold-expression based expander: f(number<0>{}), f(number<1>{}), ...
template <typename T, T... Is>
struct applier
{
    template <typename F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // tweak -fbracket-depth if compilation fails. Clang default limit is 256
        (f(number<Is>{}), ...);
    }
};
// Uses the clang builtin __make_integer_seq to instantiate applier<index_t, 0..Size-1>.
template <int32_t Size> // == sizeof...(Is)
using make_applier = __make_integer_seq<applier, index_t, Size>;
} // namespace detail
// Fast path for the common [0, N) unit-stride loop: skips the validity
// static_asserts and uses the fold-expression applier directly.
template <index_t N>
struct static_for<0, N, 1> : detail::make_applier<N>
{
    using detail::make_applier<N>::operator();
};
// static_for_product: nested compile-time loop over the cartesian product of
// several ranges (see the variadic specialization below).
template <typename... Ts>
struct static_for_product;
// A static_for<...> argument is used as-is.
template <index_t... Is>
struct static_for_product<static_for<Is...>> : public static_for<Is...>
{
};
// NOTE(review): a sequence argument is forwarded as static_for's template
// arguments, i.e. it is presumed to hold <begin, end, increment> — confirm.
template <index_t... Is>
struct static_for_product<sequence<Is...>> : public static_for<Is...>
{
};
// number<I> means the unit-stride range [0, I).
template <index_t I>
struct static_for_product<number<I>> : public static_for<0, I, 1>
{
};
// Recursive case: loop the first range, and for each index loop the rest,
// finally invoking f with one number argument per range.
template <typename First, typename... Rest>
struct static_for_product<First, Rest...>
{
    template <typename F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // Captures by value: everything captured is an empty compile-time tag.
        static_for_product<First>{}([=](auto I) {
            static_for_product<Rest...>{}([=](auto... Is) { //
                f(I, Is...);
            });
        });
    }
};
// Function object that perfect-forwards its argument unchanged,
// analogous to C++20's std::identity.
struct identity
{
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
    {
        // Equivalent to std::forward<T>(arg) for a forwarding reference.
        return static_cast<T&&>(arg);
    }
};
// Similar to identity, but takes an additional index parameter as the first argument.
// The index is ignored and only the second argument (value) is forwarded.
// Useful for indexed element-wise operations where the functor signature requires an index.
struct idx_identity
{
    template <typename I, typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
    {
        return std::forward<T>(arg);
    }
};
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
// Recursive N-d loop: peels the front length, iterates it with static_for and
// appends the current index to the ordered multi-index being built.
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_ford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
    }
    // F signature: F(sequence<...>)
    // CurrentOrderedId: sequence<...>
    template <class F, class CurrentOrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
    {
        static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
                f, CurrentOrderedId::push_back(I));
        });
    }
};
// Base case: all dimensions consumed; hand the functor the multi-index
// converted back from iteration (ordered) space to original dimension order.
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
    // F signature: F(sequence<...>)
    // OrderedId: sequence<...>
    template <class F, class OrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
    {
        // retrieve unordered Id
        f(OrderedId::reorder_old_to_new(Orders{}));
    }
};
} // namespace detail
// Lengths is sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
    CK_TILE_HOST_DEVICE constexpr static_ford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
    }
    // F signature: F(sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        // Iterate in `Orders` order; the impl converts back before calling f.
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
    }
};
namespace detail {
template <typename Indices>
struct unpack_impl;
// Calls f with every element of x: f(x.at<0>(), x.at<1>(), ...).
template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
    template <typename F, typename X>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
    {
        // The number<>-based spelling `x.at(number<Is>{})` is equivalent; the
        // permanently disabled `#if 0` copy of it was removed as dead code.
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
    }
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
// Calls f with all elements of x followed by all elements of y.
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
    template <typename F, typename X, typename Y>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
        // Dead `#if 0` alternative (number<>-based at()) removed; behavior unchanged.
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
                                  std::forward<Y>(y).template at<Js>()...);
    }
};
} // namespace detail
// Apply f to all elements of container x (compile-time analogue of std::apply).
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
    using container_t = remove_reference_t<X>;
    using indices_t   = typename arithmetic_sequence_gen<0, container_t::size(), 1>::type;
    return detail::unpack_impl<indices_t>{}(std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
// Apply f to all elements of x followed by all elements of y.
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
    using x_indices_t =
        typename arithmetic_sequence_gen<0, remove_reference_t<X>::size(), 1>::type;
    using y_indices_t =
        typename arithmetic_sequence_gen<0, remove_reference_t<Y>::size(), 1>::type;
    return detail::unpack2_impl<x_indices_t, y_indices_t>{}(
        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
// z = predicate ? x : y
// Resolved at compile time, so the two branches may even have different types:
// only the selected branch is instantiated and forwarded out.
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
        return std::forward<X>(x);
    else
        return std::forward<Y>(y);
}
} // namespace ck_tile
#pragma clang diagnostic pop

View File

@@ -0,0 +1,173 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
// Recursive N-d loop with "unpacking": each dimension is walked in steps of
// its unpack factor, and the per-pack indices are accumulated in a tuple of
// sequences that is finally handed to the functor (one argument per pack).
// NOTE: template parameter renamed RamainUnpacks -> RemainUnpacks (typo fix;
// local name only, no interface change).
template <class RemainLengths, class RemainUnpacks, class Orders>
struct static_uford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_uford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
        static_assert(RemainUnpacks::size() > 0, "wrong! should not get here");
    }
    template <class F, class CurrentUnpackIds>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
    {
        constexpr index_t pack_len = RemainUnpacks::front();
        static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
            // Duplicate each accumulated pack pack_len times, appending the
            // pack_len consecutive indices of the current dimension.
            constexpr auto new_pack = generate_tuple(
                [&](auto idx_) {
                    constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
                    constexpr auto i_pre_pack = number<idx_ / pack_len>{};
                    return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
                },
                number<CurrentUnpackIds::size() * pack_len>{});
            static_uford_impl<decltype(RemainLengths::pop_front()),
                              decltype(RemainUnpacks::pop_front()),
                              Orders>{}(f, new_pack);
        });
    }
};
// Base case: all dimensions consumed; reorder each accumulated pack back to
// the original dimension order and invoke f with one pack per argument.
template <class Orders>
struct static_uford_impl<sequence<>, sequence<>, Orders>
{
    template <class F, class PackedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
    {
        constexpr auto origin_packs = transform_tuples(
            [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
        unpack(f, origin_packs);
    }
};
// One-shot variant: walks the same recursion as static_uford_impl but decodes
// a single linear access index (current_acc) into exactly one set of packs.
// NOTE: template parameter renamed RamainUnpacks -> RemainUnpacks (typo fix;
// local name only, no interface change).
template <class RemainLengths, class RemainUnpacks, class Orders>
struct static_uford_one_shot_impl
{
    template <class F, class CurrentUnpackIds, index_t current_acc>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
    {
        constexpr auto r_lens_stride =
            reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
        constexpr auto r_upks_stride =
            reverse_exclusive_scan_sequence(RemainUnpacks{}, multiplies<>{}, number<1>{});
        // Accesses covered by one step of the current (front) dimension.
        constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
        constexpr index_t pack_len       = RemainUnpacks::front();
        constexpr index_t current_idx    = (current_acc / current_stride) * pack_len;
        constexpr auto new_pack          = generate_tuple(
            [&](auto idx_) {
                constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
                constexpr auto i_pre_pack = number<idx_ / pack_len>{};
                return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
            },
            number<CurrentUnpackIds::size() * pack_len>{});
        static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
                                   decltype(RemainUnpacks::pop_front()),
                                   Orders>{}(f, new_pack, number<current_acc % current_stride>{});
    }
};
// Base case: reorder the decoded packs back to the original dimension order
// and invoke the functor once with one pack per argument.
template <class Orders>
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
{
    template <class F, class PackedId, index_t current_acc>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
    {
        constexpr auto origin_packs = transform_tuples(
            [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
        unpack(f, origin_packs);
    }
};
} // namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// you must make sure the function passed in has same number of argument
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template <class Lengths,
          class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
    // Number of index packs (= functor arguments) per invocation.
    static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
    CK_TILE_HOST_DEVICE constexpr static_uford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
        // Each dimension length must be divisible by its unpack factor.
        static_for<0, Lengths::size(), 1>{}(
            [&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
    }
    // Total number of functor invocations over the whole iteration space.
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        using L_ = decltype(Lengths{} / Unpacks{});
        return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
    }
    // F signature: F(sequence<...> multi_id...)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
        detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
            f, make_tuple(sequence<>{}));
    }
    // this version is friendly for issue function one by one
    // Executes only the i_access-th invocation of the full iteration space.
    template <class F, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
    {
        static_assert(i_access < get_num_of_access());
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
        detail::static_uford_one_shot_impl<decltype(ordered_lengths),
                                           decltype(ordered_unpacks),
                                           Orders>{}(
            f, make_tuple(sequence<>{}), number<i_access>{});
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <stdexcept>
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Validate that `stride` is legal for an M x N matrix in the given layout:
// "C" (column-major) requires stride >= M, "R" (row-major) requires stride >= N.
// Throws std::runtime_error with a descriptive message on violation.
// (Layout is now taken by const reference instead of by value — no copy.)
inline void validate_stride(
    const std::string& Layout, int M, int N, int stride, const std::string& stride_name)
{
    if(Layout == "C" && stride < M)
    {
        throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
                                 std::to_string(stride) + ") must be greater or equal to dim " +
                                 std::to_string(M));
    }
    if(Layout == "R" && stride < N)
    {
        throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
                                 std::to_string(stride) + ") must be greater or equal to dim " +
                                 std::to_string(N));
    }
}
// Validate (and default) the three GEMM strides for C = A(MxK) * B(KxN).
// Non-positive strides are replaced by the minimal legal value for the layout
// before validation; throws std::runtime_error on violation.
// (Layout strings are now taken by const reference instead of by value.)
inline void validate_gemm_stride(const std::string& a_layout,
                                 const std::string& b_layout,
                                 const std::string& c_layout,
                                 int M,
                                 int N,
                                 int K,
                                 int Stride_A,
                                 int Stride_B,
                                 int Stride_C)
{
    // set default stride
    if(Stride_A <= 0)
        Stride_A = (a_layout == "R") ? K : M;
    if(Stride_B <= 0)
        Stride_B = (b_layout == "R") ? N : K;
    if(Stride_C <= 0)
        Stride_C = (c_layout == "R") ? N : M;
    validate_stride(a_layout, M, K, Stride_A, "Stride_A");
    validate_stride(b_layout, K, N, Stride_B, "Stride_B");
    validate_stride(c_layout, M, N, Stride_C, "Stride_C");
}
} // namespace ck_tile

View File

@@ -0,0 +1,26 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
namespace ck_tile {
namespace detail {
// Sink object modeled after std::ignore: assigning to it or calling it with
// any arguments compiles to nothing (useful to silence unused results).
struct ignore_t
{
    // Accept assignment of anything and discard it.
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
    // Accept invocation with any arguments and discard them.
    template <typename... T>
    constexpr void operator()(T&&...) const noexcept
    {
    }
};
} // namespace detail
inline constexpr detail::ignore_t ignore;
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
namespace ck_tile {
namespace literals {
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
// `_uz` / `_zu` mimic the standard `uz`/`zu` suffixes for std::size_t.
inline constexpr std::size_t operator""_uz(unsigned long long value)
{
    return static_cast<std::size_t>(value);
}
inline constexpr std::size_t operator""_zu(unsigned long long value)
{
    return static_cast<std::size_t>(value);
}
} // namespace literals
} // namespace ck_tile

View File

@@ -0,0 +1,257 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
namespace ck_tile {
// magic number division
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// 2. For int32_t as dividend: magic number division for an int32_t dividend has not been
// implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the magic number
// division implementation for uint32_t is then used. Therefore, the dividend value needs to be
// non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with 32-bit value range
// Magic-number division: replaces division by a runtime-constant divisor with
// a multiply-high, an add and a shift (Granlund–Montgomery style).
struct magic_division32_bit_range
{
    // uint32_t
    // Compute the (multiplier, shift) pair for `divisor`.
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= INT32_MAX)
        // shift = ceil(log2(divisor))
        uint32_t shift_u32 = 0;
        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };
        uint64_t tmp_u64 = static_cast<uint64_t>((1UL << shift_u32) - divisor) << 32;
        uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
        return make_tuple(multiplier_u32, shift_u32);
    }
    // Compile-time variant: divisor given as a constant<>, magic numbers
    // returned as constants so downstream code stays fully compile-time.
    template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift = tmp[number<1>{}];
        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }
    // magic division for uint32_t
    // Device path uses __umulhi for the high half of the 32x32 product; the
    // constant-evaluated branch computes it via 64-bit arithmetic instead.
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        if(__builtin_is_constant_evaluated())
        {
            uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
            return (tmp + dividend) >> shift;
        }
        else
        {
            uint32_t tmp = __umulhi(dividend, multiplier);
            return (tmp + dividend) >> shift;
        }
    }
    // Host overload. NOTE(review): same signature as the device overload above;
    // presumably CK_TILE_DEVICE/CK_TILE_HOST expand to __device__/__host__ so
    // the pair is well-formed under HIP — confirm against config.hpp.
    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (tmp + dividend) >> shift;
    }
    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
    // non-negative for result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        if(__builtin_is_constant_evaluated())
        {
            uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
            uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
            return (tmp + dividend_u32) >> shift;
        }
        else
        {
            uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
            uint32_t tmp = __umulhi(dividend_u32, multiplier);
            return (tmp + dividend_u32) >> shift;
        }
    }
    // Host overload of the int32_t variant (same non-negativity caveat).
    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
};
// magic number division
// This version only works for divisor and dividend in the range [0, 1 << 16]
// 16-bit-range variant: the multiply-high is replaced by a plain 32-bit
// multiply and a 16-bit shift, which is cheaper but restricts both divisor
// and dividend to [0, 1 << 16].
struct magic_division16_bit_range
{
    // uint32_t
    // Compute the (multiplier, shift) pair for `divisor`.
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= (1U << 16));
        // shift = ceil(log2(divisor))
        uint32_t shift_u32 = 0;
        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };
        uint32_t one = 1;
        uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
        return make_tuple(multiplier_u32, shift_u32);
    }
    // integral_constant<uint32_t, .>
    // Compile-time variant returning the magic numbers as constants.
    template <auto Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift = tmp[number<1>{}];
        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }
    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }
    // Host overload (identical arithmetic; selected via the attribute macros).
    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }
    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
    // non-negative for result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
    // Host overload of the int32_t variant (same non-negativity caveat).
    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
};
// use 32bit version
using magic_division = magic_division32_bit_range;
// Runtime magic-division helper that also stores the divisor, so divmod()
// can recover the remainder without a hardware divide.
struct mdiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough
    // prefer construct on host
    // Delegates to update() instead of duplicating the magic-number setup.
    CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) { update(divisor_); }
    CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
    // Recompute the cached magic numbers for a new divisor.
    CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
    {
        divisor  = divisor_;
        auto tmp = magic_division::calculate_magic_numbers(divisor_);
        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }
    // dividend_ / divisor via multiply-high + shift.
    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }
    // Quotient and remainder in one go (remainder derived from the quotient).
    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }
    CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
/// Compact magic-number divider: like mdiv but does not store the divisor, so
/// divmod() callers must pass back the same divisor used at construction.
struct mdiv2
{
    // 1 dword -> 2 dword storage, divisor need compute from runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
    {
        const auto magic = magic_division::calculate_magic_numbers(divisor_);
        multiplier       = magic[number<0>{}];
        shift            = magic[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}

    /// dividend_ divided by the divisor baked into (multiplier, shift).
    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    /// Quotient and remainder; divisor_ must match the construction-time divisor.
    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,54 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <type_traits>
namespace ck_tile {
namespace detail {
// Helper method to automatically determine compute type
// Selects the largest type of the two. If both of them are packed data types, defaults to fp8.
template <typename ADataType, typename BDataType>
struct auto_compute_type
{
    // Candidate compute type: whichever of the two inputs is larger.
    using LargestInputType = largest_type_t<ADataType, BDataType>;
    // Sanity check: there are no packed types larger than 1 byte yet, but if we add them
    // this logic should change
    static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));
    // If even the largest input is a packed (sub-byte) type, fall back to fp8.
    using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
};
// Helper method to determine compute type, defaulting an explicitly passed-in compute type
// ComputeDataType == void means "not specified": derive it from the inputs instead.
template <typename ComputeDataType, typename ADataType, typename BDataType>
struct mixed_prec_compute_type
{
    using type = std::conditional_t<std::is_void_v<ComputeDataType>,
                                    typename auto_compute_type<ADataType, BDataType>::type,
                                    ComputeDataType>;
};
} // namespace detail
// Convenience alias for detail::mixed_prec_compute_type<...>::type.
template <typename ComputeDataType, typename ADataType, typename BDataType>
using mixed_prec_compute_type_t =
    typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;

// Helper method to determine compute type, defaulting to input data type
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
// ComputeDataType is used.
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
using mixed_prec_compute_type_from_input_t = std::conditional_t<
    is_packed_type_v<ThisDataType>,
    std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
    ThisDataType>;
} // namespace ck_tile

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
namespace ck_tile {
/// @brief Scheduler for persistent GEMM kernels with asynchronous input streaming.
///
/// This structure enables signal-based synchronization for persistent kernels where input data
/// becomes available incrementally. It divides M-dimension tiles into chunks and uses signals
/// to coordinate between the input producer and the kernel consumer.
///
/// Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation:
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
///
/// @par Typical usage pattern:
/// 1. Set tiles_per_chunk_m to group tiles into chunks (e.g., 2 or 4 tiles per chunk)
/// 2. Set tile_idx_pivot_m as offset for chunk calculation
/// 3. Set num_chunks = ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)
/// 4. Allocate chunk_signals array with size = num_chunks
/// 5. Producer sets chunk_signals[i] = 1 when chunk i's data is ready
/// 6. Kernel waits for chunk_signals[chunk_idx] before processing each tile
struct PersistentAsyncInputScheduler
{
    // A default-constructed instance (all zeros / nullptr) disables async scheduling.

    /// @brief Number of M-dimension tiles grouped into each chunk.
    /// Grouping tiles balances synchronization overhead against input streaming granularity.
    /// Set to 0 to disable async scheduling.
    uint32_t tiles_per_chunk_m = 0;

    /// @brief Device pointer to array of signal values (uint32_t), one per chunk.
    /// Producer sets signals to coordinate when input data is ready for consumption.
    /// Set to nullptr to disable async scheduling.
    uint32_t* chunk_signals = nullptr;

    /// @brief Pivot offset for rotating the chunk assignment.
    /// Allows shifting which tiles map to which chunks, useful for load balancing.
    /// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
    int32_t tile_idx_pivot_m = 0;

    /// @brief Number of signal chunks allocated.
    /// Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m).
    /// Modulo wraparound prevents out-of-bounds access when pivot shifts chunk assignment.
    uint32_t num_chunks = 0;
};
} // namespace ck_tile

View File

@@ -0,0 +1,122 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
// Counter-based Philox-4x32 PRNG: stateless per draw, keyed by (seed, offset,
// subsequence), so any element's random bits can be regenerated independently.
class philox
{
    public:
    // seed_ selects the random stream; offset_ seeds the low 64 bits of the
    // 128-bit counter.
    CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
        : seed(reinterpret_cast<const uint2&>(seed_))
    {
        // NOTE(review): the reinterpret_cast punning here (and in mulhilo32 below)
        // assumes little-endian layout and technically violates strict aliasing;
        // consider bit_cast - confirm against supported toolchains.
        // counter.z / counter.w stay uninitialized here; get_philox_4x32 overwrites
        // them with the subsequence before they are read.
        ull2* tmp = reinterpret_cast<ull2*>(&counter);
        tmp->x    = offset_;
    }

    // Produce 4x32 random bits for a given subsequence (upper 64 bits of counter).
    CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
    {
        uint4 counter_ = counter;
        ull2* tmp      = reinterpret_cast<ull2*>(&counter_);
        // Overwrites counter_.z / counter_.w with the 64-bit subsequence.
        tmp->y     = subsequence;
        uint2 key_ = seed;
        // 7-round philox
#pragma unroll
        for(int i = 0; i < 6; i++)
        {
            counter_ = philox_single_round(counter_, key_);
            key_.x += kPhilox10A;
            key_.y += kPhilox10B;
        }
        // 6 rounds in the loop + this final round = 7 rounds total.
        uint4 output = philox_single_round(counter_, key_);
        return output;
    }

    // Write 16 random bytes (one full 4x32 draw) into out.
    CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
                                             const unsigned long long subsequence) const
    {
        uint4 tmp_ph;
        tmp_ph            = get_philox_4x32(subsequence);
        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
        out_tmp[0]        = tmp_ph.x;
        out_tmp[1]        = tmp_ph.y;
        out_tmp[2]        = tmp_ph.z;
        out_tmp[3]        = tmp_ph.w;
    }

    // Write 8 random bytes: the two 32-bit words of the draw selected by idx0/idx1.
    CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
                                            const unsigned long long subsequence,
                                            const index_t idx0,
                                            const index_t idx1) const
    {
        uint4 tmp_ph;
        tmp_ph = get_philox_4x32(subsequence);
        uint32x4_t tmp;
        tmp[0]            = tmp_ph.x;
        tmp[1]            = tmp_ph.y;
        tmp[2]            = tmp_ph.z;
        tmp[3]            = tmp_ph.w;
        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
        out_tmp[0]        = tmp[idx0];
        out_tmp[1]        = tmp[idx1];
    }

    // Write 4 random bytes: the single 32-bit word of the draw selected by idx.
    CK_TILE_HOST_DEVICE void
    get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const
    {
        uint4 tmp_ph;
        tmp_ph = get_philox_4x32(subsequence);
        uint32x4_t tmp;
        tmp[0]            = tmp_ph.x;
        tmp[1]            = tmp_ph.y;
        tmp[2]            = tmp_ph.z;
        tmp[3]            = tmp_ph.w;
        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
        out_tmp[0]        = tmp[idx];
    }

    private:
    struct ull2
    {
        uint64_t x;
        uint64_t y;
    };
    // 128-bit counter: low 64 bits = offset (ctor), high 64 bits = subsequence (per draw).
    uint4 counter;
    // 64-bit key split into two 32-bit halves.
    const uint2 seed;

    // Full 64-bit product of two 32-bit values; returns {x = low 32 bits, y = high 32 bits}.
    CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
    {
        uint2* res;
        unsigned long long tmp;
        tmp = static_cast<unsigned long long>(a) * b;
        // NOTE(review): little-endian-dependent aliasing pun; see ctor note.
        res = reinterpret_cast<uint2*>(&tmp);
        return *res;
    }

    // One Philox-4x32 round: two mul-hi/lo mixes plus key injection.
    CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
    {
        uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
        uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
        uint4 ret  = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
        return ret;
    }

    // Standard Philox constants: key increments (Weyl sequence) and round multipliers.
    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA  = 0xD2511F53;
    static const unsigned long kPhiloxSB  = 0xCD9E8D57;
};
} // namespace ck_tile

View File

@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
namespace str_literal_detail {

// Materialize an index_sequence as a tuple of integral_constant<size_t, I> values,
// so each index can be recovered from the element's *type* inside a generic lambda.
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
    return {};
}

// Compile-time strlen: length of a NUL-terminated string, excluding the terminator.
constexpr size_t constexpr_strlen(const char* c)
{
    size_t len = 0;
    for(; c[len] != '\0'; ++len)
    {
    }
    return len;
}

} // namespace str_literal_detail
// Compile-time string carried in a template character pack; supports
// concatenation and separator-joined repetition, all at compile time.
template <char... Xs>
struct str_literal
{
    // NUL-terminated character array usable directly with printf("%s").
    static constexpr const char data[] = {Xs..., '\0'};
    // Length excluding the terminating NUL.
    static constexpr const size_t size = sizeof...(Xs);

    // Concatenation: str_literal<'a'> + str_literal<'b'> -> str_literal<'a','b'>.
    template <char... Ys>
    CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
    {
        return str_literal<Xs...., Ys...>{};
    }

    // Repeat this literal N times with `sep` inserted between consecutive copies,
    // e.g. str_literal<'a'>::duplicate_n<3>(str_literal<','>{}) yields "a,a,a".
    template <size_t N, char... Ys>
    CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
    {
        if constexpr(N == 0)
            return str_literal<>{};
        else if constexpr(N == 1)
            return str_literal<Xs...>{};
        else
            // Peel one (sep + literal) off the tail and recurse on the head.
            return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
    }
};
// Build a str_literal<...> from an ordinary string literal, e.g.
// make_str_literal("hi") -> str_literal<'h','i'>{}. Implemented as a macro
// because a string literal cannot be passed as a template argument in C++17;
// the generic lambda recovers each index from the integral_constant tuple.
#define make_str_literal(lit_) \
    std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
               str_literal_detail::makeTuple( \
                   std::make_index_sequence<str_literal_detail::constexpr_strlen(lit_)>()))
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
    // sizeof(T) == 0 is never true, but being dependent on T it only fires when this
    // primary template is actually instantiated, i.e. when no specialization exists.
    static_assert(sizeof(T) == 0,
                  "No print implementation available for this type. Please specialize "
                  "ck_tile::print for your type.");
}
// NOTE(review): explicit specializations of function templates are not implicitly
// inline; in a #pragma once header included from multiple TUs this risks ODR
// violations unless CK_TILE_HOST_DEVICE expands to include `inline` - confirm.

/// Specialization for int
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
    printf("%d", value);
}

/// Specialization for float
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
    printf("%f", value);
}

/// Specialization for double
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
    printf("%f", value);
}

/// Specialization for long
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
    printf("%ld", value);
}

/// Specialization for unsigned int
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
    printf("%u", value);
}

/// Specialization for char
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
    printf("%c", value);
}
/// Specialization for array: renders "[e0, e1, ...]" by delegating each element
/// to the matching print() overload.
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
    printf("[");
    bool first = true;
    for(const T& element : value)
    {
        if(!first)
            printf(", ");
        first = false;
        print(element); // recurse into nested arrays / element types
    }
    printf("]");
}
} // namespace ck_tile

View File

@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
namespace ck_tile {
// return 0 if data is not fp16 or fp32
// Primary template: fallback for element types without a specialization below;
// always yields 0, i.e. no pseudo-randomness is produced.
template <typename T, uint32_t seed_>
struct prand_generator_t
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};
// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
    // Cheap hash mixing (value bits, element id, seed) into a pseudo-random uint32.
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
    {
        const uint32_t bits = bit_cast<uint32_t>(val);
        // fold the high half-word into the low half-word
        uint32_t mixed = (bits & 0xFFFFu) ^ (bits >> 16);
        // bit-rotate-style scramble followed by a multiplicative hash
        mixed = ((mixed & 31) << 11) | (mixed >> 5);
        mixed *= 0x7000149;
        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        // So, it can have an effect of using same id for multiple elements when the id is
        // very large!
        return mixed ^ 0x13371337 ^ (id * 229791) ^ seed;
    }
};
// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
    // Cheap hash mixing (value bits, element id, seed) into a pseudo-random uint32.
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
    {
        const uint32_t bits = uint32_t(bit_cast<uint16_t>(val)) & 0xFFFFu;
        // bit-rotate-style scramble followed by a multiplicative hash
        uint32_t mixed = ((bits & 31) << 11) | (bits >> 5);
        mixed *= 0x7000149;
        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        // So, it can have an effect of using same id for multiple elements when the id is
        // very large!
        return mixed ^ 0x13371337 ^ (id * 229791) ^ seed;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,143 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
namespace ReduceOp {
// y = ReduceOp(y, x);
// Sum reduction: y = y + x. Identity is 0.
struct Add
{
    template <typename T>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    // Native-precision path for types that can accumulate directly.
    template <typename T,
              typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return y + x;
    }

    // Low-precision path: accumulate in fp32 and convert back.
    template <typename T,
              typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
    {
        float y_ = type_convert<float>(y);
        float x_ = type_convert<float>(x);
        return type_convert<T>(y_ + x_);
    }

    // Memory operation to use when this reduction is performed via atomics.
    CK_TILE_HOST_DEVICE static constexpr auto GetAtomic()
    {
        return memory_operation_enum::atomic_add;
    }
};
// Sum-of-squares reduction: y = y + x*x. Identity is 0.
struct SquareAdd
{
    template <typename T>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    // Native-precision path for types that can accumulate directly.
    template <typename T,
              typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return y + (x * x);
    }

    // Low-precision path: square and accumulate in fp32, then convert back.
    template <typename T,
              typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
    {
        float y_ = type_convert<float>(y);
        float x_ = type_convert<float>(x);
        return type_convert<T>(y_ + (x_ * x_));
    }
};
// Max reduction: y = max(y, x). Identity is the lowest representable value.
struct Max
{
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return numeric<T>::lowest();
    };

    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return max(y, x);
    }

    // Overload with changed flag for index tracking
    // `changed` is set only on strict improvement (x > y); ties leave it untouched,
    // which lets callers keep the first-occurring index for equal values.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
    {
        T new_max = max(y, x);
        if(x > y)
        {
            changed = true;
        }
        return new_max;
    }
};
// Absolute-max reduction: y = max(y, |x|). Identity is 0 (|x| >= 0 always).
struct AbsMax
{
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return numeric<T>::zero();
    };

    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return max(y, abs(x));
    }

    // Overload with changed flag for index tracking
    // `changed` is set only on strict improvement (|x| > y); ties leave it untouched,
    // which lets callers keep the first-occurring index for equal values.
    // Fix: abs(x) was evaluated twice (once for the max, once for the comparison);
    // compute it once and reuse.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
    {
        const T abs_x = abs(x);
        if(abs_x > y)
        {
            changed = true;
        }
        return max(y, abs_x);
    }
};
} // namespace ReduceOp
} // namespace ck_tile

View File

@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
struct AccumulateWithIndex
{
    /// @brief Fold new_value into current_value via reduce_func and keep the index of
    /// the winning value; on ties, the smaller (first-occurring) index wins.
    template <typename ReduceOp, typename T, typename IndexType>
    CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
                                        T& current_value,
                                        IndexType& current_index,
                                        const T& new_value,
                                        const IndexType& new_index) const
    {
        bool changed  = false;
        current_value = reduce_func(current_value, new_value, changed);
        if(changed)
        {
            // new_value strictly won the reduction: take its index.
            current_index = new_index;
        }
        else if(new_index < current_index)
        {
            // new_value did not win, but if the reverse reduction does not win
            // either, the two values tie; prefer the smaller index so the result
            // is deterministic regardless of accumulation order.
            bool reverse_changed = false;
            reduce_func(new_value, current_value, reverse_changed);
            if(!reverse_changed)
            {
                current_index = new_index;
            }
        }
    }
};
/// Plain reduction step without index tracking: current_value = op(current_value, new_value).
struct Accumulate
{
    template <typename ReduceOp, typename T>
    CK_TILE_HOST_DEVICE void
    operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
    {
        const T reduced = reduce_func(current_value, new_value);
        current_value   = reduced;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Compile-time mutable counter implemented with "stateful metaprogramming":
// each next() call permanently defines a new friend function (slot), and SFINAE
// probes which slots are already defined to find the current count. Counts are
// scaled/offset as: value = raw_count * Step + Start. `Context` isolates
// independent counters from each other.
template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
    public:
    // Advance the counter; `Unique` must differ per call site (see NEXT_SC macro,
    // which feeds __COUNTER__ through the NTTP overload below).
    template <typename Unique>
    static constexpr index_t next()
    {
        return next<Unique>(0) * Step + Start;
    }

    // NTTP variant: a fresh local struct makes each instantiation unique.
    template <unsigned long long>
    static constexpr index_t next()
    {
        struct Unique
        {
        };
        return next<Unique>(0) * Step + Start;
    }

    // Read the counter without advancing it.
    template <typename Unique>
    static constexpr index_t current()
    {
        return current<Unique>(0) * Step + Start;
    }

    template <unsigned long long>
    static constexpr index_t current()
    {
        struct Unique
        {
        };
        return current<Unique>(0) * Step + Start;
    }

    private:
    // One slot per count value; slot_allocated(slot<I>) gets *defined* (via friend
    // injection in allocate_slot) the first time count I is reached.
    template <index_t I>
    struct slot
    {
        _Pragma("GCC diagnostic push");
        _Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
        friend constexpr bool slot_allocated(slot<I>);
        _Pragma("GCC diagnostic pop");
    };

    // Instantiating allocate_slot<I> injects the definition of slot_allocated(slot<I>).
    template <index_t I>
    struct allocate_slot
    {
        friend constexpr bool slot_allocated(slot<I>) { return true; }
        enum
        {
            value = I
        };
    };

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
    static constexpr index_t next(index_t)
    {
        return next<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
    // allocate_slot<I>.
    template <typename Unique, index_t I = 0>
    static constexpr index_t next(double)
    {
        return allocate_slot<I>::value;
    }

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
    static constexpr index_t current(index_t)
    {
        return current<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will return the current counter, or assert
    // in case next() hasn't been called yet.
    template <typename Unique, index_t I = Start>
    static constexpr index_t current(double)
    {
        static_assert(I != 0, "You must invoke next() first");
        return I - 1;
    }
};
namespace impl {
// Tag type parameterized on __COUNTER__ so every MAKE_SC() gets a distinct counter.
template <int I>
struct static_counter_uniq_;
}

// The macros below wrap __COUNTER__ usage and silence the related C2y
// compatibility warnings around each expansion.

// clang-format off
// Create a fresh counter starting at 0, step 1.
#define MAKE_SC() \
    _Pragma("clang diagnostic push") \
    _Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
    _Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>>{} \
    _Pragma("clang diagnostic pop")

// Create a fresh counter with explicit start/step.
#define MAKE_SC_WITH(start_, step_) \
    _Pragma("clang diagnostic push") \
    _Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
    _Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") ck_tile:: \
        static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_>{} \
    _Pragma("clang diagnostic pop")

// Advance counter c_ and yield its new value as a constant expression.
#define NEXT_SC(c_) \
    _Pragma("clang diagnostic push") \
    _Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
    _Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") c_.next<__COUNTER__>() \
    _Pragma("clang diagnostic pop")

// Advance counter c_, perturbing the uniquifier with static_i_.
#define NEXT_SCI(c_, static_i_) \
    _Pragma("clang diagnostic push") \
    _Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
    _Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") \
        c_.next<__COUNTER__ + static_i_>() _Pragma("clang diagnostic pop")
// clang-format on
// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
} // namespace ck_tile

View File

@@ -0,0 +1,73 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use c++20 nontype template with struct to implement this
// TO_SEQUENCE(a, n): turn the first n elements of a constexpr array `a` into a
// ck_tile::sequence<...>. Both `a` and `n` must be constant expressions.
#if 1
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
// Implementation: an immediately-invoked explicit-template-parameter generic lambda
// expands an index pack over a.at(...).
#define TO_SEQUENCE(a, n) \
    _Pragma("clang diagnostic push") _Pragma( \
        "clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
        ck_tile::sequence<IDX_IDX_...>) \
    { \
        return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
    } \
    (ck_tile::make_index_sequence<n>{}); \
    _Pragma("clang diagnostic pop")
#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
// Fallback for compilers without C++20 explicit lambda templates: enumerate the
// supported sizes (0..10) by hand.
#define TO_SEQUENCE(a, n) \
    [a, n] { \
        static_assert(a.size() >= n, "wrong! out of bound"); \
        static_assert(n <= 10, "not implemented"); \
        if constexpr(n == 0) \
        { \
            return ck_tile::sequence<>{}; \
        } \
        else if constexpr(n == 1) \
        { \
            return ck_tile::sequence<a[0]>{}; \
        } \
        else if constexpr(n == 2) \
        { \
            return ck_tile::sequence<a[0], a[1]>{}; \
        } \
        else if constexpr(n == 3) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2]>{}; \
        } \
        else if constexpr(n == 4) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3]>{}; \
        } \
        else if constexpr(n == 5) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{}; \
        } \
        else if constexpr(n == 6) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{}; \
        } \
        else if constexpr(n == 7) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{}; \
        } \
        else if constexpr(n == 8) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{}; \
        } \
        else if constexpr(n == 9) \
        { \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
        } \
        else if constexpr(n == 10) \
        { \
            return ck_tile:: \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{}; \
        } \
    }()
#endif

View File

@@ -0,0 +1,218 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// S: scalar type (or it can be non-scalar type)
// NX: # of vector before transpose
// NY: # of vector after transpose
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
template <typename S_, index_t NX, index_t NY>
struct transpose_vectors
{
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S = remove_cvref_t<S_>;

    using VX = array<S, s_per_x>;
    using VY = array<S, s_per_y>;

    // Dispatch tags: one specialized fast path per (element size, tile shape)
    // combination, plus a generic scalar fallback. Selected by tag_dispatch().
    struct generic_tag
    {
    };
    struct bytesize2_2x2_tag
    {
    };
    struct bytesize1_4x4_tag
    {
    };
    struct bytesize1_2x2_tag
    {
    };

    // Fallback: plain scalar element-by-element transpose, works for any S/NX/NY.
    CK_TILE_DEVICE static constexpr void
    apply_impl(const thread_buffer<VX, NX>& vx_tuple, thread_buffer<VY, NY>& vy_tuple, generic_tag)
    {
        static_for<0, NY, 1>{}([&](auto iy) {
            static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
        });
    }

    // 16-bit elements: transpose 2x2 tiles with v_perm_b32 byte shuffles.
    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
                                                    thread_buffer<VY, NY>& vy_tuple,
                                                    bytesize2_2x2_tag)
    {
        static_assert(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0, "wrong!");

        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};

        using S2 = array<S, 2>;

        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // 2 16bitx2 data from vx_tuple to be transposed
                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);

                // transpose 2x2 16bit
                // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
                //              --  -- -- -- --  -- -- -- --       -  -  -  -
                // index        7   6  5  4  3   2  1  0           33 77 44 88
                // index is reversed because of little endianness (least significant bits first)
                const S2 y_s2_0 = bit_cast<S2>(
                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
                                          bit_cast<uint32_t>(x_s2_1),
                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[1, 0, 5, 4] = (C1.D1.C0.D0)
                                          0x01'00'05'04));
                const S2 y_s2_1 = bit_cast<S2>(
                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
                                          bit_cast<uint32_t>(x_s2_1),
                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[3, 2, 7, 6] = (A1.B1.A0.B0)
                                          0x03'02'07'06));

                // write transposed 2x2 result:
                // write (C1.D1.C0.D0)
                vy_tuple(iy).set_as(ix / I2, y_s2_0);
                // write (A1.B1.A0.B0)
                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
            });
        });
    }

    // 8-bit elements: transpose 4x4 tiles using pairs of v_perm_b32 shuffles
    // (two passes: gather low/high byte pairs, then interleave).
    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
                                                    thread_buffer<VY, NY>& vy_tuple,
                                                    bytesize1_4x4_tag)
    {
        static_assert(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0, "wrong!");

        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};
        constexpr auto I3 = number<3>{};
        constexpr auto I4 = number<4>{};

        using S4 = array<S, 4>;

        // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 4>{}([&](auto iy) {
            static_for<0, NX, 4>{}([&](auto ix) {
                // read A0.B0.C0.D0
                const S4 x_s4_0 = vx_tuple[ix].template get_as<S4>(iy / I4);
                // read A1.B1.C1.D1
                const S4 x_s4_1 = vx_tuple[ix + I1].template get_as<S4>(iy / I4);
                // read A2.B2.C2.D2
                const S4 x_s4_2 = vx_tuple[ix + I2].template get_as<S4>(iy / I4);
                // read A3.B3.C3.D3
                const S4 x_s4_3 = vx_tuple[ix + I3].template get_as<S4>(iy / I4);

                // (A1.B1.C1.D1.A0.B0.C0.D0)[5, 1, 4, 0] = (C1.C0.D1.D0)
                uint32_t t_s4_0 = __builtin_amdgcn_perm(
                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x05'01'04'00);
                // (A3.B3.C3.D3.A2.B2.C2.D2)[5, 1, 4, 0] = (C3.C2.D3.D2)
                uint32_t t_s4_1 = __builtin_amdgcn_perm(
                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x05'01'04'00);

                // (C3.C2.D3.D2.C1.C0.D1.D0)[5, 4, 1, 0] = (D3.D2.D1.D0)
                const S4 y_s4_0 =
                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
                // (C3.C2.D3.D2.C1.C0.D1.D0)[7, 6, 3, 2] = (C3.C2.C1.C0)
                const S4 y_s4_1 =
                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));

                // (A1.B1.C1.D1.A0.B0.C0.D0)[7, 3, 6, 2] = (A1.A0.B1.B0)
                t_s4_0 = __builtin_amdgcn_perm(
                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x07'03'06'02);
                // (A3.B3.C3.D3.A2.B2.C2.D2)[7, 3, 6, 2] = (A3.A2.B3.B2)
                t_s4_1 = __builtin_amdgcn_perm(
                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x07'03'06'02);

                // (A3.A2.B3.B2.A1.A0.B1.B0)[5, 4, 1, 0] = (B3.B2.B1.B0)
                const S4 y_s4_2 =
                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
                // (A3.A2.B3.B2.A1.A0.B1.B0)[7, 6, 3, 2] = (A3.A2.A1.A0)
                const S4 y_s4_3 =
                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));

                // write transposed 4x4 result:
                // write (D3.D2.D1.D0)
                vy_tuple(iy).set_as(ix / I4, y_s4_0);
                // write (C3.C2.C1.C0)
                vy_tuple(iy + I1).set_as(ix / I4, y_s4_1);
                // write (B3.B2.B1.B0)
                vy_tuple(iy + I2).set_as(ix / I4, y_s4_2);
                // write (A3.A2.A1.A0)
                vy_tuple(iy + I3).set_as(ix / I4, y_s4_3);
            });
        });
    }

    // 8-bit elements with only 2-divisible tile counts: transpose 2x2 byte tiles.
    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
                                                    thread_buffer<VY, NY>& vy_tuple,
                                                    bytesize1_2x2_tag)
    {
        static_assert(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0, "wrong!");

        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};

        using S2 = array<S, 2>;

        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // read A0.B0
                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
                // read A1.B1
                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);

                // v_perm_b32: pick 4 bytes from 8 bytes in (input0.input1) using the mask
                const S2 y_s2_0 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 0, 4] = (00.00.B1.B0)
                    0x0C'0C'00'04)));
                const S2 y_s2_1 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 1, 5] = (00.00.A1.A0)
                    0x0C'0C'01'05)));

                // write transposed 2x2 result:
                // write (B1.B0)
                vy_tuple(iy).set_as(ix / I2, y_s2_0);
                // write (A1.A0)
                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
            });
        });
    }

    // Pick the fastest applicable implementation for (sizeof(S), NX, NY).
    CK_TILE_DEVICE static constexpr auto tag_dispatch()
    {
        if constexpr(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0)
        {
            return bytesize2_2x2_tag{};
        }
        else if constexpr(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0)
        {
            return bytesize1_4x4_tag{};
        }
        else if constexpr(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0)
        {
            return bytesize1_2x2_tag{};
        }
        else
        {
            return generic_tag{};
        }
    }

    // Entry point: transpose vx_tuple ([NX, NY] of S) into vy_tuple ([NY, NX] of S).
    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
                                   thread_buffer<VY, NY>& vy_tuple) const
    {
        apply_impl(vx_tuple, vy_tuple, tag_dispatch());
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,207 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <tuple>
#include <type_traits>
#include <stdint.h>
namespace ck_tile {
// remove_cvref_t
// Local spellings of the standard remove_* traits (remove_cvref_t is a
// C++20 name, provided here for C++17).
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
// Propagate const-ness from `From` onto `To`:
//   copy_const_t<int, float>       -> float
//   copy_const_t<const int, float> -> const float
template <typename From, typename To>
struct copy_const
{
    static_assert(!std::is_const_v<From>);
    using type = To;
};

template <typename From, typename To>
struct copy_const<const From, To>
{
    using type = std::add_const_t<typename copy_const<From, To>::type>;
};

template <typename From, typename To>
using copy_const_t = typename copy_const<From, To>::type;
namespace detail {
// Detection idiom (std::experimental::is_detected style): the specialization
// matches only when Op<Args...> is a valid type.
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Default;
};

template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};

} // namespace detail
// Placeholder "no such type" used as the detector's default; deliberately
// unusable (cannot be destroyed, copied, or assigned).
struct nonesuch
{
    ~nonesuch()               = delete;
    nonesuch(nonesuch const&) = delete;
    void operator=(nonesuch const&) = delete;
};

// true_type iff Op<Args...> is a well-formed type.
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
namespace impl {
// Probe for a static member function T::is_static().
template <typename T>
using has_is_static = decltype(T::is_static());

template <typename T>
struct is_static_impl
{
    // Use T::is_static() when the type provides it; otherwise treat arithmetic
    // types as statically-known values.
    static constexpr bool value = []() {
        if constexpr(is_detected<has_is_static, T>{})
            return T::is_static();
        else
            return std::is_arithmetic<T>::value;
    }();
};
} // namespace impl

// True when T represents a compile-time-known value (cv/ref qualifiers ignored).
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;

template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;

// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
// Cast between two pointer types with a C-style cast (reinterpret-style),
// locally silencing clang's old-style-cast and cast-align warnings.
// Both PY and PX must be pointer types (enforced by the enable_if).
template <
    typename PY,
    typename PX,
    typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
    return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}
// is_any_of<CompareTo, Rest...>: true when CompareTo is the same type as at
// least one of Rest... . An empty Rest pack yields false (empty || fold).
template <typename CompareTo, typename... Rest>
struct is_any_of : std::bool_constant<(std::is_same_v<CompareTo, Rest> || ...)>
{
};
/**
 * @brief Check whether a value equals any element of a list of candidates.
 *
 * Each candidate is converted to T before the equality comparison.
 *
 * @tparam T The type of the search value
 * @tparam Ts The types of the candidate values (each convertible to T)
 * @param search The value to search for
 * @param searchList The candidate values (at least one required)
 * @return true if the search value compares equal to any candidate, false otherwise
 */
template <typename T, typename... Ts>
// TODO: c++20 requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >= 1))
CK_TILE_HOST_DEVICE static constexpr bool is_any_value_of(T search, Ts... searchList)
{
    static_assert(sizeof...(Ts) >= 1, "searchList must contain at least one value");
    static_assert((std::is_convertible<Ts, T>::value && ...),
                  "All searchList values must be convertible to the type of search");
    // Fold over ||: short-circuits on the first match.
    return ((search == static_cast<T>(searchList)) || ...);
}
// Helper to check if a type is a specialization of a given template,
// e.g. is_specialization_of<std::tuple<int>, std::tuple>::value == true.
// Note: only matches templates whose parameters are all type parameters
// (non-type template parameters will not match this pattern).
template <typename Test, template <typename...> class RefTemplate>
struct is_specialization_of : std::false_type
{
};
template <template <typename...> class RefTemplate, typename... Args>
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
{
};
// Helper to get a tuple element or default type.
namespace detail {
// Out-of-bounds case: yield DefaultType. The bool dispatch exists because
// std::tuple_element_t on an out-of-range index is ill-formed, so it must
// not be named in the out-of-bounds branch (std::conditional_t would
// instantiate both sides).
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch
{
    using type = DefaultType;
};
// In-bounds case: defer to std::tuple_element_t.
template <std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch<true, Idx, Tuple, DefaultType>
{
    using type = std::tuple_element_t<Idx, Tuple>;
};
} // namespace detail
// tuple_element_or_default: std::tuple_element_t<Idx, Tuple> when Idx is in
// range, otherwise DefaultType. cv/ref qualifiers on Tuple_ are stripped
// before inspection.
template <typename Tuple_, std::size_t Idx, typename DefaultType>
struct tuple_element_or_default
{
    using Tuple = remove_cvref_t<Tuple_>;
    static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
    using type = typename detail::
        tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
};
// Convenience alias for tuple_element_or_default<...>::type.
template <typename Tuple_, std::size_t Idx, typename DefaultType>
using tuple_element_or_default_t =
    typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
// Helper struct to determine if a type is packed (more than 1 element per
// byte), i.e. numeric_traits<T>::PackedSize > 1 for sub-byte element types.
template <typename T>
struct is_packed_type
{
    static constexpr bool value = numeric_traits<T>::PackedSize > 1;
};
// `inline constexpr` (not `static constexpr`): a header-scope `static`
// variable template has internal linkage and is duplicated per TU; `inline`
// gives one program-wide definition, consistent with is_static_v above.
template <typename T>
inline constexpr bool is_packed_type_v = is_packed_type<T>::value;
// Helper alias selecting whichever of the two types has the larger sizeof;
// on a tie, ADataType is chosen.
template <typename ADataType, typename BDataType>
using largest_type_t =
    std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
} // namespace ck_tile

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Function-composition helper. composes<F, Fs...> stores F (via the private
// composes<F> base) plus the remaining functors, and applies them
// right-to-left: composes(f, g)(x) evaluates f(g(x)).
template <typename F, typename... Fs>
struct composes : private composes<F>
{
    // First argument constructs the outermost functor F; the remaining
    // arguments construct the inner composes<Fs...> chain.
    template <typename FirstArg, typename... RestArgs>
    CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
        : composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
    {
    }
    // Evaluate the inner chain first, then feed its result to F.
    template <typename Arg>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
    }
    private:
    composes<Fs...> inner_;
};
// Base case of the composition chain: wraps a single functor F by value.
template <typename F>
struct composes<F>
{
    static_assert(!std::is_reference_v<F>);
    // SFINAE-restricted to arguments F is constructible from, so this
    // template does not compete with the implicit copy/move constructors.
    template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
    CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
    {
    }
    // Invoke the wrapped functor; enabled only when const F is invocable
    // with Arg.
    template <typename Arg,
              typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return f_(std::forward<Arg>(arg));
    }
    private:
    F f_;
};
// Factory: build a composes<...> chain from any number of functors,
// storing each by decayed value.
template <class... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_composes(Ts&&... ts)
{
    return composes<remove_cvref_t<Ts>...>{std::forward<Ts>(ts)...};
}
// saturates<SaturateType>: functor that clamps an arithmetic accumulator
// value into the representable range of SaturateType.
template <typename SaturateType>
struct saturates
{
    // NOTE: this function does not return a SaturateType value (the result
    // stays in AccType); it is the user's responsibility to do a further
    // cast or not.
    template <typename AccType>
    CK_TILE_HOST_DEVICE constexpr auto
    operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
    {
        return clamp(a_,
                     type_convert<AccType>(numeric<SaturateType>::lowest()),
                     type_convert<AccType>(numeric<SaturateType>::max()));
    }
};
} // namespace ck_tile