mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
|
||||
{
|
||||
static_assert(__has_builtin(__builtin_bit_cast), "");
|
||||
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
|
||||
|
||||
return __builtin_bit_cast(Y, x);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
161
include/ck_tile/core/utility/debug.hpp
Normal file
161
include/ck_tile/core/utility/debug.hpp
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct thread_buffer;
|
||||
|
||||
// Usage example: CK_PRINTF<float>{}(tensor);
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF;
|
||||
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINTF<ConvertTo,
|
||||
str_literal<FMTChars...>,
|
||||
str_literal<PREFIXChars...>,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return std::make_tuple(make_str_literal("%8.3f"), T{});
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return std::make_tuple(make_str_literal("%5d"), T{});
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return std::make_tuple(make_str_literal("%5u"), T{});
|
||||
else if constexpr(sizeof(T) == 1)
|
||||
return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
|
||||
else if constexpr(sizeof(T) == 2)
|
||||
return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
|
||||
else if constexpr(sizeof(T) == 4)
|
||||
return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
|
||||
else
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
template <typename T>
|
||||
using default_format_t =
|
||||
std::remove_reference_t<decltype(std::get<0>(default_format_and_type<T>()))>;
|
||||
template <typename T>
|
||||
using default_type_t =
|
||||
std::remove_reference_t<decltype(std::get<1>(default_format_and_type<T>()))>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>,
|
||||
Args&&... args) const
|
||||
{
|
||||
using FMT1 = std::
|
||||
conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data,
|
||||
get_thread_id(),
|
||||
N,
|
||||
args...,
|
||||
bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(
|
||||
buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... TS, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
|
||||
Args&&... args) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
print(std::forward<T>(x));
|
||||
}
|
||||
template <typename... Ts>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
|
||||
{
|
||||
using base_t = CK_PRINTF<Ts...>;
|
||||
|
||||
template <typename T, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf, std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* RAII struct which inserts start/end markers into the generated assembly.
|
||||
*
|
||||
* Usage:
|
||||
* - Create an `AsmScopeMarker` object at the beginning of a scope or code block.
|
||||
* - Its constructor will emit a "CK_ASM_SCOPE_START" marker into the assembly.
|
||||
* - When the object goes out of scope (end of block, return, exception, etc.),
|
||||
* the destructor will emit a "CK_ASM_SCOPE_END" marker.
|
||||
*
|
||||
* Example:
|
||||
* {
|
||||
* [[maybe_unused]] AsmScopeMarker marker; // Emits CK_ASM_SCOPE_START
|
||||
* // ... code you want to delimit in assembly ...
|
||||
* } // marker goes out of scope → Emits CK_ASM_SCOPE_END
|
||||
*
|
||||
*/
|
||||
struct AsmScopeMarker
|
||||
{
|
||||
// in some future version of clang we might be able to use string_view to customize
|
||||
CK_TILE_HOST_DEVICE AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_START"); }
|
||||
CK_TILE_HOST_DEVICE ~AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_END"); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
220
include/ck_tile/core/utility/env.hpp
Normal file
220
include/ck_tile/core/utility/env.hpp
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename... Args>
|
||||
void CK_TILE_ERROR(Args&&... args) noexcept
|
||||
{
|
||||
std::ostringstream oss;
|
||||
(oss << ... << args);
|
||||
std::cerr << "[CK_TILE_ERROR] " << oss.str() << std::endl;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void CK_TILE_INFO(Args&&... args) noexcept
|
||||
{
|
||||
std::ostringstream oss;
|
||||
(oss << ... << args);
|
||||
std::cout << "[CK_TILE_INFO] " << oss.str() << std::endl;
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <size_t N>
|
||||
bool is_any_of(const char* const (&names)[N], const std::string& str)
|
||||
{
|
||||
return std::any_of(std::begin(names), std::end(names), [&](const char* inner_str) {
|
||||
return str == inner_str;
|
||||
});
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ParseEnvVal
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct ParseEnvVal<bool>
|
||||
{
|
||||
static bool parse_env_var_value(const char* vp)
|
||||
{
|
||||
std::string value_env_str{vp};
|
||||
|
||||
for(auto& c : value_env_str)
|
||||
{
|
||||
if(std::isalpha(c) != 0)
|
||||
{
|
||||
c = std::tolower(static_cast<unsigned char>(c));
|
||||
}
|
||||
}
|
||||
|
||||
if(is_any_of(enabled_names, value_env_str))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(is_any_of(disabled_names, value_env_str))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Invalid value for env variable");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"};
|
||||
static constexpr const char* disabled_names[] = {
|
||||
"disable", "disabled", "0", "no", "off", "false"};
|
||||
};
|
||||
|
||||
// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
|
||||
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
|
||||
template <>
|
||||
struct ParseEnvVal<uint64_t>
|
||||
{
|
||||
static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ParseEnvVal<std::string>
|
||||
{
|
||||
static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct EnvVar
|
||||
{
|
||||
private:
|
||||
T value{};
|
||||
bool is_unset = true;
|
||||
|
||||
public:
|
||||
const T& GetValue() const { return value; }
|
||||
|
||||
bool IsUnset() const { return is_unset; }
|
||||
|
||||
void Unset() { is_unset = true; }
|
||||
|
||||
void UpdateValue(const T& val)
|
||||
{
|
||||
is_unset = false;
|
||||
value = val;
|
||||
}
|
||||
|
||||
explicit EnvVar(const char* const name, const T& def_val)
|
||||
{
|
||||
// NOLINTNEXTLINE (concurrency-mt-unsafe)
|
||||
const char* vp = std::getenv(name);
|
||||
if(vp != nullptr) // a value was provided
|
||||
{
|
||||
is_unset = false;
|
||||
value = ParseEnvVal<T>::parse_env_var_value(vp);
|
||||
}
|
||||
else // no value provided, use default value
|
||||
{
|
||||
value = def_val;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
// Static inside function hides the variable and provides
|
||||
// thread-safety/locking
|
||||
// Used in global namespace
|
||||
#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val) \
|
||||
namespace ck_tile::env { \
|
||||
struct name \
|
||||
{ \
|
||||
static_assert(std::is_same_v<name, ::ck_tile::env::name>, \
|
||||
"CK_TILE_DECLARE_ENV* must be used in the global namespace"); \
|
||||
using value_type = type; \
|
||||
static ck_tile::internal::EnvVar<type>& Ref() \
|
||||
{ \
|
||||
static ck_tile::internal::EnvVar<type> var{#name, default_val}; \
|
||||
return var; \
|
||||
} \
|
||||
}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "")
|
||||
|
||||
#define CK_TILE_ENV(name) \
|
||||
ck_tile::env::name {}
|
||||
|
||||
template <class EnvVar>
|
||||
inline const std::string& EnvGetString(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsEnabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsDisabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline uint64_t EnvValue(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsUnset(EnvVar)
|
||||
{
|
||||
return EnvVar::Ref().IsUnset();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
void EnvUnset(EnvVar)
|
||||
{
|
||||
EnvVar::Ref().Unset();
|
||||
}
|
||||
|
||||
/// Updates the cached value of an environment variable
|
||||
template <typename EnvVar, typename ValueType>
|
||||
void UpdateEnvVar(EnvVar, const ValueType& val)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
|
||||
EnvVar::Ref().UpdateValue(val);
|
||||
}
|
||||
|
||||
template <typename EnvVar>
|
||||
void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
{
|
||||
EnvVar::Ref().UpdateValue(
|
||||
ck_tile::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(
|
||||
val.data()));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#pragma clang diagnostic pop
|
||||
275
include/ck_tile/core/utility/functional.hpp
Normal file
275
include/ck_tile/core/utility/functional.hpp
Normal file
@@ -0,0 +1,275 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_for()
|
||||
{
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
|
||||
"NBegin >= NEnd)");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct applier
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
|
||||
(f(number<Is>{}), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t Size> // == sizeof...(Is)
|
||||
using make_applier = __make_integer_seq<applier, index_t, Size>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N>
|
||||
struct static_for<0, N, 1> : detail::make_applier<N>
|
||||
{
|
||||
using detail::make_applier<N>::operator();
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct static_for_product;
|
||||
template <index_t... Is>
|
||||
struct static_for_product<static_for<Is...>> : public static_for<Is...>
|
||||
{
|
||||
};
|
||||
template <index_t... Is>
|
||||
struct static_for_product<sequence<Is...>> : public static_for<Is...>
|
||||
{
|
||||
};
|
||||
template <index_t I>
|
||||
struct static_for_product<number<I>> : public static_for<0, I, 1>
|
||||
{
|
||||
};
|
||||
template <typename First, typename... Rest>
|
||||
struct static_for_product<First, Rest...>
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
static_for_product<First>{}([=](auto I) {
|
||||
static_for_product<Rest...>{}([=](auto... Is) { //
|
||||
f(I, Is...);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct identity
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
// Similar to identity, but takes an additional index parameter as the first argument.
|
||||
// The index is ignored and only the second argument (value) is forwarded.
|
||||
// Useful for indexed element-wise operations where the functor signature requires an index.
|
||||
struct idx_identity
|
||||
{
|
||||
template <typename I, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1>
|
||||
struct unpack2_impl;
|
||||
|
||||
// TODO: remove this, after properly implementing unpack that takes any number of containers
|
||||
template <index_t... Is, index_t... Js>
|
||||
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
|
||||
{
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
|
||||
std::forward<Y>(y).at(number<Js>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
|
||||
std::forward<Y>(y).template at<Js>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
// z = predicate ? x : y
|
||||
template <bool predicate, typename X, typename Y>
|
||||
constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
{
|
||||
if constexpr(predicate)
|
||||
{
|
||||
return std::forward<X>(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::forward<Y>(y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
#pragma clang diagnostic pop
|
||||
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// This file should not be included inside tuple.hpp!
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
template <class F, class CurrentUnpackIds>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
|
||||
{
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_one_shot_impl
|
||||
{
|
||||
template <class F, class CurrentUnpackIds, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
|
||||
{
|
||||
constexpr auto r_lens_stride =
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
|
||||
constexpr auto r_upks_stride =
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
|
||||
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// TODO: we may unify static_ford/static_uford in the future
|
||||
//
|
||||
// loop over nd space(sequence) with packs
|
||||
// you must make sure the function passed in has same number of argument
|
||||
//
|
||||
// e.g.
|
||||
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
|
||||
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
|
||||
//
|
||||
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
|
||||
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
|
||||
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
|
||||
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
|
||||
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
|
||||
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
|
||||
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
|
||||
// ...
|
||||
template <class Lengths,
|
||||
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_uford
|
||||
{
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
using L_ = decltype(Lengths{} / Unpacks{});
|
||||
|
||||
return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id...)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
|
||||
f, make_tuple(sequence<>{}));
|
||||
}
|
||||
|
||||
// this version is friendly for issue function one by one
|
||||
template <class F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
|
||||
{
|
||||
static_assert(i_access < get_num_of_access());
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
|
||||
decltype(ordered_unpacks),
|
||||
Orders>{}(
|
||||
f, make_tuple(sequence<>{}), number<i_access>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
51
include/ck_tile/core/utility/gemm_validation.hpp
Normal file
51
include/ck_tile/core/utility/gemm_validation.hpp
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
inline void
|
||||
validate_stride(std::string Layout, int M, int N, int stride, const std::string& stride_name)
|
||||
{
|
||||
if(Layout == "C" && stride < M)
|
||||
{
|
||||
throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
|
||||
std::to_string(stride) + ") must be greater or equal to dim " +
|
||||
std::to_string(M));
|
||||
}
|
||||
if(Layout == "R" && stride < N)
|
||||
{
|
||||
throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
|
||||
std::to_string(stride) + ") must be greater or equal to dim " +
|
||||
std::to_string(N));
|
||||
}
|
||||
}
|
||||
|
||||
inline void validate_gemm_stride(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int Stride_A,
|
||||
int Stride_B,
|
||||
int Stride_C)
|
||||
{
|
||||
// set default stride
|
||||
if(Stride_A <= 0)
|
||||
Stride_A = (a_layout == "R") ? K : M;
|
||||
if(Stride_B <= 0)
|
||||
Stride_B = (b_layout == "R") ? N : K;
|
||||
if(Stride_C <= 0)
|
||||
Stride_C = (c_layout == "R") ? N : M;
|
||||
|
||||
validate_stride(a_layout, M, K, Stride_A, "Stride_A");
|
||||
validate_stride(b_layout, K, N, Stride_B, "Stride_B");
|
||||
validate_stride(c_layout, M, N, Stride_C, "Stride_C");
|
||||
}
|
||||
} // namespace ck_tile
|
||||
26
include/ck_tile/core/utility/ignore.hpp
Normal file
26
include/ck_tile/core/utility/ignore.hpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
struct ignore_t
|
||||
{
|
||||
template <typename T>
|
||||
constexpr void operator=(T&&) const noexcept
|
||||
{
|
||||
}
|
||||
template <typename... T>
|
||||
constexpr void operator()(T&&...) const noexcept
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
inline constexpr detail::ignore_t ignore;
|
||||
|
||||
} // namespace ck_tile
|
||||
22
include/ck_tile/core/utility/literals.hpp
Normal file
22
include/ck_tile/core/utility/literals.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace ck_tile {
|
||||
namespace literals {
|
||||
// [P0330] Literal Suffix for (signed) size_t (C++23)
|
||||
// ref: https://wg21.link/p0330r8
|
||||
inline constexpr std::size_t operator""_uz(unsigned long long size)
|
||||
{
|
||||
return static_cast<std::size_t>(size);
|
||||
}
|
||||
|
||||
inline constexpr std::size_t operator""_zu(unsigned long long size)
|
||||
{
|
||||
return static_cast<std::size_t>(size);
|
||||
}
|
||||
} // namespace literals
|
||||
} // namespace ck_tile
|
||||
257
include/ck_tile/core/utility/magic_div.hpp
Normal file
257
include/ck_tile/core/utility/magic_div.hpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// magic number division
|
||||
// Caution:
|
||||
// 1. For uint32_t as dividend: magic number division implementation being used would produce
|
||||
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
|
||||
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
|
||||
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
|
||||
// division implementation for uint32_t is then used. Therefore, dividend value need to be
|
||||
// non-negative.
|
||||
// TODO:
|
||||
// 1. Implement magic number divison for int32_t
|
||||
// 2. Implement magic number divison for unit32_t with 32-bit value range
|
||||
struct magic_division32_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= INT32_MAX)
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint64_t tmp_u64 = static_cast<uint64_t>((1UL << shift_u32) - divisor) << 32;
|
||||
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(constant<multiplier>{}, constant<shift>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// magic number division
|
||||
// This version on works for divisor and dividended between [0, 1 << 16]
|
||||
struct magic_division16_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= (1U << 16));
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint32_t one = 1;
|
||||
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <auto Divisor>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(constant<multiplier>{}, constant<shift>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// use 32bit version
|
||||
using magic_division = magic_division32_bit_range;
|
||||
|
||||
struct mdiv
|
||||
{
|
||||
// 1 dword -> 3 dword storage
|
||||
uint32_t divisor;
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
|
||||
{
|
||||
divisor = divisor_;
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
|
||||
};
|
||||
|
||||
struct mdiv2
|
||||
{
|
||||
// 1 dword -> 2 dword storage, divisor need compute from runtime
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper method to automatically determine compute type
|
||||
// Selects the largest type of the two. If both of them are packed data types, defaults to fp8.
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct auto_compute_type
|
||||
{
|
||||
using LargestInputType = largest_type_t<ADataType, BDataType>;
|
||||
|
||||
// Sanity check: there are no packed types larger than 1 byte yet, but if we add them
|
||||
// this logic should change
|
||||
static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));
|
||||
|
||||
using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
|
||||
};
|
||||
|
||||
// Helper method to determine compute type, defaulting an explicitly passed-in compute type
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
struct mixed_prec_compute_type
|
||||
{
|
||||
using type = std::conditional_t<std::is_void_v<ComputeDataType>,
|
||||
typename auto_compute_type<ADataType, BDataType>::type,
|
||||
ComputeDataType>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
using mixed_prec_compute_type_t =
|
||||
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
|
||||
|
||||
// Helper method to determine compute type, defaulting to input data type
|
||||
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
|
||||
// ComputeDataType is used.
|
||||
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
|
||||
using mixed_prec_compute_type_from_input_t = std::conditional_t<
|
||||
is_packed_type_v<ThisDataType>,
|
||||
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
|
||||
ThisDataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Scheduler for persistent GEMM kernels with asynchronous input streaming.
|
||||
///
|
||||
/// This structure enables signal-based synchronization for persistent kernels where input data
|
||||
/// becomes available incrementally. It divides M-dimension tiles into chunks and uses signals
|
||||
/// to coordinate between the input producer and the kernel consumer.
|
||||
///
|
||||
/// Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation:
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
///
|
||||
/// @par Typical usage pattern:
|
||||
/// 1. Set tiles_per_chunk_m to group tiles into chunks (e.g., 2 or 4 tiles per chunk)
|
||||
/// 2. Set tile_idx_pivot_m as offset for chunk calculation
|
||||
/// 3. Set num_chunks = ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)
|
||||
/// 4. Allocate chunk_signals array with size = num_chunks
|
||||
/// 5. Producer sets chunk_signals[i] = 1 when chunk i's data is ready
|
||||
/// 6. Kernel waits for chunk_signals[chunk_idx] before processing each tile
|
||||
struct PersistentAsyncInputScheduler
|
||||
{
|
||||
/// @brief Number of M-dimension tiles grouped into each chunk.
|
||||
/// Grouping tiles balances synchronization overhead against input streaming granularity.
|
||||
/// Set to 0 to disable async scheduling.
|
||||
uint32_t tiles_per_chunk_m = 0;
|
||||
|
||||
/// @brief Device pointer to array of signal values (uint32_t), one per chunk.
|
||||
/// Producer sets signals to coordinate when input data is ready for consumption.
|
||||
/// Set to nullptr to disable async scheduling.
|
||||
uint32_t* chunk_signals = nullptr;
|
||||
|
||||
/// @brief Pivot offset for rotating the chunk assignment.
|
||||
/// Allows shifting which tiles map to which chunks, useful for load balancing.
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
int32_t tile_idx_pivot_m = 0;
|
||||
|
||||
/// @brief Number of signal chunks allocated.
|
||||
/// Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m).
|
||||
/// Modulo wraparound prevents out-of-bounds access when pivot shifts chunk assignment.
|
||||
uint32_t num_chunks = 0;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
122
include/ck_tile/core/utility/philox_rand.hpp
Normal file
122
include/ck_tile/core/utility/philox_rand.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
|
||||
uint2 key_ = seed;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 6; i++)
|
||||
{
|
||||
counter_ = philox_single_round(counter_, key_);
|
||||
key_.x += kPhilox10A;
|
||||
key_.y += kPhilox10B;
|
||||
}
|
||||
uint4 output = philox_single_round(counter_, key_);
|
||||
return output;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
|
||||
out_tmp[0] = tmp_ph.x;
|
||||
out_tmp[1] = tmp_ph.y;
|
||||
out_tmp[2] = tmp_ph.z;
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t idx0,
|
||||
const index_t idx1) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[idx0];
|
||||
out_tmp[1] = tmp[idx1];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[idx];
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
tmp = static_cast<unsigned long long>(a) * b;
|
||||
res = reinterpret_cast<uint2*>(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
121
include/ck_tile/core/utility/print.hpp
Normal file
121
include/ck_tile/core/utility/print.hpp
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace str_literal_detail {
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
} // namespace str_literal_detail
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <size_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
str_literal_detail::makeTuple( \
|
||||
std::make_index_sequence<str_literal_detail::constexpr_strlen(lit_)>()))
|
||||
|
||||
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
|
||||
/// can be printed.
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const T&)
|
||||
{
|
||||
static_assert(sizeof(T) == 0,
|
||||
"No print implementation available for this type. Please specialize "
|
||||
"ck_tile::print for your type.");
|
||||
}
|
||||
|
||||
/// Specialization for int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const int& value)
|
||||
{
|
||||
printf("%d", value);
|
||||
}
|
||||
|
||||
/// Specialization for float
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const float& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for double
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const double& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for long
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const long& value)
|
||||
{
|
||||
printf("%ld", value);
|
||||
}
|
||||
|
||||
/// Specialization for unsigned int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
|
||||
{
|
||||
printf("%u", value);
|
||||
}
|
||||
|
||||
/// Specialization for char
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const char& value)
|
||||
{
|
||||
printf("%c", value);
|
||||
}
|
||||
|
||||
/// Specialization for array
|
||||
template <typename T, size_t N>
|
||||
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
|
||||
{
|
||||
printf("[");
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(value[i]); // Recursively call print for each element
|
||||
}
|
||||
printf("]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
58
include/ck_tile/core/utility/random.hpp
Normal file
58
include/ck_tile/core/utility/random.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T, uint32_t seed_>
|
||||
struct prand_generator_t
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
|
||||
};
|
||||
|
||||
// version for fp32
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<float, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
|
||||
{
|
||||
uint32_t x = bit_cast<uint32_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
// version for fp16
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<half_t, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
|
||||
{
|
||||
uint16_t x = bit_cast<uint16_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
143
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
143
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace ReduceOp {
|
||||
// y = ReduceOp(y, x);
|
||||
struct Add
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
float x_ = type_convert<float>(x);
|
||||
|
||||
return type_convert<T>(y_ + x_);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAtomic()
|
||||
{
|
||||
return memory_operation_enum::atomic_add;
|
||||
}
|
||||
};
|
||||
|
||||
struct SquareAdd
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
float x_ = type_convert<float>(x);
|
||||
return type_convert<T>(y_ + (x_ * x_));
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, x);
|
||||
if(x > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::zero();
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, abs(x));
|
||||
if(abs(x) > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
} // namespace ck_tile
|
||||
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
|
||||
struct AccumulateWithIndex
|
||||
{
|
||||
template <typename ReduceOp, typename T, typename IndexType>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
|
||||
T& current_value,
|
||||
IndexType& current_index,
|
||||
const T& new_value,
|
||||
const IndexType& new_index) const
|
||||
{
|
||||
bool changed = false;
|
||||
current_value = reduce_func(current_value, new_value, changed);
|
||||
|
||||
if(changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
else if(new_index < current_index)
|
||||
{
|
||||
bool reverse_changed = false;
|
||||
reduce_func(new_value, current_value, reverse_changed);
|
||||
|
||||
if(!reverse_changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Accumulate
|
||||
{
|
||||
template <typename ReduceOp, typename T>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
|
||||
{
|
||||
current_value = reduce_func(current_value, new_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
134
include/ck_tile/core/utility/static_counter.hpp
Normal file
134
include/ck_tile/core/utility/static_counter.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Context, index_t Start = 0, index_t Step = 1>
|
||||
struct static_counter
|
||||
{
|
||||
public:
|
||||
template <typename Unique>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <typename Unique>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
private:
|
||||
template <index_t I>
|
||||
struct slot
|
||||
{
|
||||
_Pragma("GCC diagnostic push");
|
||||
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
|
||||
friend constexpr bool slot_allocated(slot<I>);
|
||||
_Pragma("GCC diagnostic pop");
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct allocate_slot
|
||||
{
|
||||
friend constexpr bool slot_allocated(slot<I>) { return true; }
|
||||
enum
|
||||
{
|
||||
value = I
|
||||
};
|
||||
};
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t next(index_t)
|
||||
{
|
||||
return next<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
|
||||
// allocate_slot<I>.
|
||||
template <typename Unique, index_t I = 0>
|
||||
static constexpr index_t next(double)
|
||||
{
|
||||
return allocate_slot<I>::value;
|
||||
}
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t current(index_t)
|
||||
{
|
||||
return current<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will return the current counter, or assert
|
||||
// in case next() hasn't been called yet.
|
||||
template <typename Unique, index_t I = Start>
|
||||
static constexpr index_t current(double)
|
||||
{
|
||||
static_assert(I != 0, "You must invoke next() first");
|
||||
|
||||
return I - 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
template <int I>
|
||||
struct static_counter_uniq_;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define MAKE_SC() \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>>{} \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#define MAKE_SC_WITH(start_, step_) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") ck_tile:: \
|
||||
static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_>{} \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#define NEXT_SC(c_) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") c_.next<__COUNTER__>() \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#define NEXT_SCI(c_, static_i_) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wpre-c2y-compat\"") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc2y-extensions\"") \
|
||||
c_.next<__COUNTER__ + static_i_>() _Pragma("clang diagnostic pop")
|
||||
// clang-format on
|
||||
|
||||
// Usage:
|
||||
// constexpr auto c = MAKE_SC()
|
||||
// NEXT_SC(c) // -> constexpr 0
|
||||
// NEXT_SC(c) // -> constexpr 1
|
||||
// NEXT_SC(c) // -> constexpr 2
|
||||
} // namespace ck_tile
|
||||
73
include/ck_tile/core/utility/to_sequence.hpp
Normal file
73
include/ck_tile/core/utility/to_sequence.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
// TODO: use c++20 nontype template with struct to implement this
|
||||
|
||||
#if 1
|
||||
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
#else
|
||||
// Macro function
|
||||
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
[a, n] { \
|
||||
static_assert(a.size() >= n, "wrong! out of bound"); \
|
||||
static_assert(n <= 10, "not implemented"); \
|
||||
if constexpr(n == 0) \
|
||||
{ \
|
||||
return ck_tile::sequence<>{}; \
|
||||
} \
|
||||
else if constexpr(n == 1) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 2) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 3) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 4) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 5) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 6) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 7) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 8) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 9) \
|
||||
{ \
|
||||
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
|
||||
} \
|
||||
else if constexpr(n == 10) \
|
||||
{ \
|
||||
return ck_tile:: \
|
||||
sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
218
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
218
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S: scalar type (or it can be non-scalar type)
|
||||
// NX: # of vector before transpose
|
||||
// NY: # of vector after transpose
|
||||
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
|
||||
template <typename S_, index_t NX, index_t NY>
|
||||
struct transpose_vectors
|
||||
{
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = remove_cvref_t<S_>;
|
||||
|
||||
using VX = array<S, s_per_x>;
|
||||
using VY = array<S, s_per_y>;
|
||||
|
||||
struct generic_tag
|
||||
{
|
||||
};
|
||||
struct bytesize2_2x2_tag
|
||||
{
|
||||
};
|
||||
struct bytesize1_4x4_tag
|
||||
{
|
||||
};
|
||||
struct bytesize1_2x2_tag
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static constexpr void
|
||||
apply_impl(const thread_buffer<VX, NX>& vx_tuple, thread_buffer<VY, NY>& vy_tuple, generic_tag)
|
||||
{
|
||||
static_for<0, NY, 1>{}([&](auto iy) {
|
||||
static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
|
||||
thread_buffer<VY, NY>& vy_tuple,
|
||||
bytesize2_2x2_tag)
|
||||
{
|
||||
static_assert(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
|
||||
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
using S2 = array<S, 2>;
|
||||
// loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 2>{}([&](auto iy) {
|
||||
static_for<0, NX, 2>{}([&](auto ix) {
|
||||
// 2 16bitx2 data from vx_tuple to be transposed
|
||||
const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
|
||||
const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
|
||||
|
||||
// transpose 2x2 16bit
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
const S2 y_s2_0 = bit_cast<S2>(
|
||||
__builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
|
||||
bit_cast<uint32_t>(x_s2_1),
|
||||
// (A0.B0.C0.D0.A1.B1.C1.D1)[1, 0, 5, 4] = (C1.D1.C0.D0)
|
||||
0x01'00'05'04));
|
||||
const S2 y_s2_1 = bit_cast<S2>(
|
||||
__builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
|
||||
bit_cast<uint32_t>(x_s2_1),
|
||||
// (A0.B0.C0.D0.A1.B1.C1.D1)[3, 2, 7, 6] = (A1.B1.A0.B0)
|
||||
0x03'02'07'06));
|
||||
|
||||
// write transposed 2x2 result:
|
||||
// write (C1.D1.C0.D0)
|
||||
vy_tuple(iy).set_as(ix / I2, y_s2_0);
|
||||
// write (A1.B1.A0.B0)
|
||||
vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
|
||||
thread_buffer<VY, NY>& vy_tuple,
|
||||
bytesize1_4x4_tag)
|
||||
{
|
||||
static_assert(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0, "wrong!");
|
||||
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
constexpr auto I4 = number<4>{};
|
||||
using S4 = array<S, 4>;
|
||||
// loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 4>{}([&](auto iy) {
|
||||
static_for<0, NX, 4>{}([&](auto ix) {
|
||||
// read A0.B0.C0.D0
|
||||
const S4 x_s4_0 = vx_tuple[ix].template get_as<S4>(iy / I4);
|
||||
// read A1.B1.C1.D1
|
||||
const S4 x_s4_1 = vx_tuple[ix + I1].template get_as<S4>(iy / I4);
|
||||
// read A2.B2.C2.D2
|
||||
const S4 x_s4_2 = vx_tuple[ix + I2].template get_as<S4>(iy / I4);
|
||||
// read A3.B3.C3.D3
|
||||
const S4 x_s4_3 = vx_tuple[ix + I3].template get_as<S4>(iy / I4);
|
||||
|
||||
// (A1.B1.C1.D1.A0.B0.C0.D0)[5, 1, 4, 0] = (C1.C0.D1.D0)
|
||||
uint32_t t_s4_0 = __builtin_amdgcn_perm(
|
||||
bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x05'01'04'00);
|
||||
// (A3.B3.C3.D3.A2.B2.C2.D2)[5, 1, 4, 0] = (C3.C2.D3.D2)
|
||||
uint32_t t_s4_1 = __builtin_amdgcn_perm(
|
||||
bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x05'01'04'00);
|
||||
// (C3.C2.D3.D2.C1.C0.D1.D0)[5, 4, 1, 0] = (D3.D2.D1.D0)
|
||||
const S4 y_s4_0 =
|
||||
bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
|
||||
// (C3.C2.D3.D2.C1.C0.D1.D0)[7, 6, 3, 2] = (C3.C2.C1.C0)
|
||||
const S4 y_s4_1 =
|
||||
bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
|
||||
// (A1.B1.C1.D1.A0.B0.C0.D0)[7, 3, 6, 2] = (A1.A0.B1.B0)
|
||||
t_s4_0 = __builtin_amdgcn_perm(
|
||||
bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x07'03'06'02);
|
||||
// (A3.B3.C3.D3.A2.B2.C2.D2)[7, 3, 6, 2] = (A3.A2.B3.B2)
|
||||
t_s4_1 = __builtin_amdgcn_perm(
|
||||
bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x07'03'06'02);
|
||||
// (A3.A2.B3.B2.A1.A0.B1.B0)[5, 4, 1, 0] = (B3.B2.B1.B0)
|
||||
const S4 y_s4_2 =
|
||||
bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
|
||||
// (A3.A2.B3.B2.A1.A0.B1.B0)[7, 6, 3, 2] = (A3.A2.A1.A0)
|
||||
const S4 y_s4_3 =
|
||||
bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
|
||||
|
||||
// write transposed 4x4 result:
|
||||
// write (D3.D2.D1.D0)
|
||||
vy_tuple(iy).set_as(ix / I4, y_s4_0);
|
||||
// write (C3.C2.C1.C0)
|
||||
vy_tuple(iy + I1).set_as(ix / I4, y_s4_1);
|
||||
// write (B3.B2.B1.B0)
|
||||
vy_tuple(iy + I2).set_as(ix / I4, y_s4_2);
|
||||
// write (A3.A2.A1.A0)
|
||||
vy_tuple(iy + I3).set_as(ix / I4, y_s4_3);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
|
||||
thread_buffer<VY, NY>& vy_tuple,
|
||||
bytesize1_2x2_tag)
|
||||
{
|
||||
static_assert(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
|
||||
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
using S2 = array<S, 2>;
|
||||
// loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 2>{}([&](auto iy) {
|
||||
static_for<0, NX, 2>{}([&](auto ix) {
|
||||
// read A0.B0
|
||||
const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
|
||||
// read A1.B1
|
||||
const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
|
||||
|
||||
// v_perm_b32: pick 4 bytes from 8 bytes in (input0.input1) using the mask
|
||||
const S2 y_s2_0 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
|
||||
static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
|
||||
static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
|
||||
// (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 0, 4] = (00.00.B1.B0)
|
||||
0x0C'0C'00'04)));
|
||||
|
||||
const S2 y_s2_1 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
|
||||
static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
|
||||
static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
|
||||
// (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 1, 5] = (00.00.A1.A0)
|
||||
0x0C'0C'01'05)));
|
||||
|
||||
// write transposed 2x2 result:
|
||||
// write (B1.B0)
|
||||
vy_tuple(iy).set_as(ix / I2, y_s2_0);
|
||||
// write (A1.A0)
|
||||
vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto tag_dispatch()
|
||||
{
|
||||
if constexpr(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0)
|
||||
{
|
||||
return bytesize2_2x2_tag{};
|
||||
}
|
||||
else if constexpr(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0)
|
||||
{
|
||||
return bytesize1_4x4_tag{};
|
||||
}
|
||||
else if constexpr(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0)
|
||||
{
|
||||
return bytesize1_2x2_tag{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return generic_tag{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
|
||||
thread_buffer<VY, NY>& vy_tuple) const
|
||||
{
|
||||
apply_impl(vx_tuple, vy_tuple, tag_dispatch());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
207
include/ck_tile/core/utility/type_traits.hpp
Normal file
207
include/ck_tile/core/utility/type_traits.hpp
Normal file
@@ -0,0 +1,207 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// remove_cvref_t
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
using remove_pointer_t = typename std::remove_pointer<T>::type;
|
||||
|
||||
template <typename From, typename To>
|
||||
struct copy_const
|
||||
{
|
||||
static_assert(!std::is_const_v<From>);
|
||||
|
||||
using type = To;
|
||||
};
|
||||
|
||||
template <typename From, typename To>
|
||||
struct copy_const<const From, To>
|
||||
{
|
||||
using type = std::add_const_t<typename copy_const<From, To>::type>;
|
||||
};
|
||||
|
||||
template <typename From, typename To>
|
||||
using copy_const_t = typename copy_const<From, To>::type;
|
||||
|
||||
namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
{
|
||||
using value_t = std::false_type;
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <class Default, template <class...> class Op, class... Args>
|
||||
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
{
|
||||
using value_t = std::true_type;
|
||||
using type = Op<Args...>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct nonesuch
|
||||
{
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
void operator=(nonesuch const&) = delete;
|
||||
};
|
||||
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
using has_is_static = decltype(T::is_static());
|
||||
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_is_static, T>{})
|
||||
return T::is_static();
|
||||
else
|
||||
return std::is_arithmetic<T>::value;
|
||||
}();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
// FIXME: do we need this anymore?
|
||||
template <
|
||||
typename PY,
|
||||
typename PX,
|
||||
typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic ignored "-Wcast-align"
|
||||
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename CompareTo, typename... Rest>
|
||||
struct is_any_of : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename CompareTo, typename FirstType>
|
||||
struct is_any_of<CompareTo, FirstType> : std::is_same<CompareTo, FirstType>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename CompareTo, typename FirstType, typename... Rest>
|
||||
struct is_any_of<CompareTo, FirstType, Rest...>
|
||||
: std::integral_constant<bool,
|
||||
std::is_same<CompareTo, FirstType>::value ||
|
||||
is_any_of<CompareTo, Rest...>::value>
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper to check if a value is in a list of values
|
||||
* @tparam T The type of the search value
|
||||
* @tparam Ts The types of the search list values
|
||||
* @param search The value to search for
|
||||
* @param searchList The list of values to search in
|
||||
* @return true if the search value is in the search list, false otherwise
|
||||
*/
|
||||
template <typename T, typename... Ts>
|
||||
// TODO: c++20 requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >= 1))
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_any_value_of(T search, Ts... searchList)
|
||||
{
|
||||
static_assert((std::is_convertible<Ts, T>::value && ...),
|
||||
"All searchList values must be convertible to the type of search");
|
||||
static_assert(sizeof...(Ts) >= 1, "searchList must contain at least one value");
|
||||
|
||||
return ((search == static_cast<T>(searchList)) || ...);
|
||||
}
|
||||
|
||||
// Helper to check if a type is a specialization of a given template
|
||||
template <typename Test, template <typename...> class RefTemplate>
|
||||
struct is_specialization_of : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <template <typename...> class RefTemplate, typename... Args>
|
||||
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to get a tuple element or default type
|
||||
namespace detail {
|
||||
|
||||
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
|
||||
struct tuple_element_or_default_dispatch
|
||||
{
|
||||
using type = DefaultType;
|
||||
};
|
||||
|
||||
template <std::size_t Idx, typename Tuple, typename DefaultType>
|
||||
struct tuple_element_or_default_dispatch<true, Idx, Tuple, DefaultType>
|
||||
{
|
||||
using type = std::tuple_element_t<Idx, Tuple>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
struct tuple_element_or_default
|
||||
{
|
||||
using Tuple = remove_cvref_t<Tuple_>;
|
||||
static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
|
||||
using type = typename detail::
|
||||
tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
|
||||
};
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
using tuple_element_or_default_t =
|
||||
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
|
||||
|
||||
// Helper struct to determine if a type is packed (more than 1 element per byte)
|
||||
template <typename T>
|
||||
struct is_packed_type
|
||||
{
|
||||
static constexpr bool value = numeric_traits<T>::PackedSize > 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_packed_type_v = is_packed_type<T>::value;
|
||||
|
||||
// Helper definition to take the largest sizes type
|
||||
template <typename ADataType, typename BDataType>
|
||||
using largest_type_t =
|
||||
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
71
include/ck_tile/core/utility/unary_element_function.hpp
Normal file
71
include/ck_tile/core/utility/unary_element_function.hpp
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename F, typename... Fs>
|
||||
struct composes : private composes<F>
|
||||
{
|
||||
template <typename FirstArg, typename... RestArgs>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
|
||||
: composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Arg>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
|
||||
{
|
||||
return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
|
||||
}
|
||||
|
||||
private:
|
||||
composes<Fs...> inner_;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
struct composes<F>
|
||||
{
|
||||
static_assert(!std::is_reference_v<F>);
|
||||
|
||||
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Arg,
|
||||
typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
|
||||
{
|
||||
return f_(std::forward<Arg>(arg));
|
||||
}
|
||||
|
||||
private:
|
||||
F f_;
|
||||
};
|
||||
|
||||
template <class... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_composes(Ts&&... ts)
|
||||
{
|
||||
return composes<remove_cvref_t<Ts>...>{std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
{
|
||||
// NOTE: this function does not return SaturateType value
|
||||
// it is user's responsiblity to do further cast or not
|
||||
template <typename AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
{
|
||||
return clamp(a_,
|
||||
type_convert<AccType>(numeric<SaturateType>::lowest()),
|
||||
type_convert<AccType>(numeric<SaturateType>::max()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user