mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Introduce gemm_softmax_gemm to codegen.
This commit is contained in:
@@ -1005,6 +1005,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
@@ -1042,5 +1043,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include "ck/utility/functional2.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace detail {
|
||||
@@ -37,7 +39,7 @@ struct get_carrier<3>
|
||||
{
|
||||
using value_type = uint32_t;
|
||||
|
||||
std::array<std::byte, 3> bytes;
|
||||
Array<ck::byte, 3> bytes;
|
||||
static_assert(sizeof(bytes) <= sizeof(value_type));
|
||||
|
||||
// replacement of host std::copy_n()
|
||||
@@ -61,22 +63,22 @@ struct get_carrier<3>
|
||||
// method to trigger template substitution failure
|
||||
__device__ carrier(const carrier& other) noexcept
|
||||
{
|
||||
copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
|
||||
copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ carrier& operator=(value_type value) noexcept
|
||||
{
|
||||
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
|
||||
copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ operator value_type() const noexcept
|
||||
{
|
||||
std::byte result[sizeof(value_type)];
|
||||
ck::byte result[sizeof(value_type)];
|
||||
|
||||
copy_n(bytes.begin(), bytes.size(), result);
|
||||
copy_n(bytes.begin(), bytes.Size(), result);
|
||||
|
||||
return *reinterpret_cast<const value_type*>(result);
|
||||
}
|
||||
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
|
||||
{
|
||||
constexpr unsigned object_size = sizeof(int64_t);
|
||||
constexpr unsigned second_part_offset = object_size / 2;
|
||||
auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
|
||||
alignas(int64_t) std::byte to_obj[object_size];
|
||||
auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
|
||||
alignas(int64_t) ck::byte to_obj[object_size];
|
||||
|
||||
using Sgpr = uint32_t;
|
||||
|
||||
@@ -124,15 +126,15 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
|
||||
|
||||
template <
|
||||
typename Object,
|
||||
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
|
||||
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
|
||||
__device__ auto amd_wave_read_first_lane(const Object& obj)
|
||||
{
|
||||
using Size = unsigned;
|
||||
constexpr Size SgprSize = 4;
|
||||
constexpr Size ObjectSize = sizeof(Object);
|
||||
|
||||
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
|
||||
alignas(Object) std::byte to_obj[ObjectSize];
|
||||
auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
|
||||
alignas(Object) ck::byte to_obj[ObjectSize];
|
||||
|
||||
constexpr Size RemainedSize = ObjectSize % SgprSize;
|
||||
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
|
||||
|
||||
@@ -38,6 +38,8 @@ struct Array
|
||||
}
|
||||
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
|
||||
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
|
||||
__host__ __device__ constexpr TData* begin() { return &mData[0]; }
|
||||
__host__ __device__ constexpr TData* end() { return &mData[NSize]; }
|
||||
};
|
||||
|
||||
// empty Array
|
||||
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
|
||||
{
|
||||
using data_type = remove_cvref_t<X>;
|
||||
return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
|
||||
return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
|
||||
}
|
||||
|
||||
// make empty array
|
||||
|
||||
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
|
||||
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
|
||||
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
|
||||
@@ -5,8 +5,25 @@
|
||||
|
||||
#include "ck/utility/statically_indexed_array.hpp"
|
||||
|
||||
#ifdef __HIPCC_RTC__
|
||||
/// Definitions from <cstdint>, <cmath> conflict with
|
||||
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
|
||||
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
using int16_t = signed short;
|
||||
using uint16_t = unsigned short;
|
||||
using float_t = float;
|
||||
#endif // __HIPCC_RTC__
|
||||
|
||||
namespace ck {
|
||||
|
||||
#ifdef __HIPCC_RTC__
|
||||
using byte = unsigned char;
|
||||
#else
|
||||
using std::byte;
|
||||
#endif
|
||||
|
||||
using bhalf_t = ushort;
|
||||
using half_t = _Float16;
|
||||
using int4_t = _BitInt(4);
|
||||
@@ -1060,6 +1077,146 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
|
||||
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
|
||||
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
|
||||
|
||||
#ifdef __HIPCC_RTC__
|
||||
template <typename T>
|
||||
struct NumericLimits;
|
||||
|
||||
template <>
|
||||
struct NumericLimits<int32_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<int16_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<int8_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<uint32_t>
|
||||
{
|
||||
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<uint16_t>
|
||||
{
|
||||
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<float>
|
||||
{
|
||||
static constexpr unsigned int binary_min = 0x00800000;
|
||||
static constexpr unsigned int binary_max = 0x7F7FFFFF;
|
||||
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
|
||||
static constexpr unsigned int binary_qnan = 0xFFC00001;
|
||||
static constexpr unsigned int binary_inf = 0x7F8000000;
|
||||
|
||||
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
|
||||
|
||||
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<half_t>
|
||||
{
|
||||
static constexpr unsigned short binary_min = 0x0400;
|
||||
static constexpr unsigned short binary_max = 0x7BFF;
|
||||
static constexpr unsigned short binary_lowest = 0xFBFF;
|
||||
static constexpr unsigned short binary_qnan = 0x7FFF;
|
||||
|
||||
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct NumericLimits<int4_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
|
||||
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
|
||||
};
|
||||
#else
|
||||
template <typename T>
|
||||
struct NumericLimits
|
||||
{
|
||||
@@ -1151,6 +1308,7 @@ struct NumericLimits<bf8_t>
|
||||
|
||||
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct NumericUtils
|
||||
|
||||
@@ -4,11 +4,26 @@
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
#ifdef __HIPCC_RTC__
|
||||
template <bool B, class T = void>
|
||||
struct enable_if
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct enable_if<true, T>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <bool B, class T = void>
|
||||
using enable_if_t = typename enable_if<B, T>::type;
|
||||
|
||||
#else
|
||||
template <bool B, typename T = void>
|
||||
using enable_if = std::enable_if<B, T>;
|
||||
|
||||
template <bool B, typename T = void>
|
||||
using enable_if_t = typename std::enable_if<B, T>::type;
|
||||
|
||||
#endif
|
||||
} // namespace ck
|
||||
|
||||
@@ -183,3 +183,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
|
||||
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
|
||||
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
{
|
||||
if constexpr(predicate)
|
||||
{
|
||||
return std::forward<X>(x);
|
||||
return ck::forward<X>(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::forward<Y>(y);
|
||||
return ck::forward<Y>(y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
|
||||
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
|
||||
template <typename F, typename X, typename Y>
|
||||
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
|
||||
std::forward<Y>(y).At(Number<Js>{})...);
|
||||
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
|
||||
ck::forward<Y>(y).At(Number<Js>{})...);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
ck::forward<F>(f), ck::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -9,14 +9,14 @@ namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
{
|
||||
using value_t = std::false_type;
|
||||
using value_t = ck::false_type;
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <class Default, template <class...> class Op, class... Args>
|
||||
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
|
||||
{
|
||||
using value_t = std::true_type;
|
||||
using value_t = ck::true_type;
|
||||
using type = Op<Args...>;
|
||||
};
|
||||
} // namespace detail
|
||||
@@ -32,12 +32,12 @@ template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
template <typename T>
|
||||
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
|
||||
using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
|
||||
using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
|
||||
using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <ostream>
|
||||
|
||||
#pragma once
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
|
||||
{
|
||||
switch(s)
|
||||
@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -30,7 +30,7 @@ struct MagicDivision
|
||||
// WARNING: magic division is only applicable for division inside this range.
|
||||
// You should use the return value of CalculateMagicNumbers, if division is not inside this
|
||||
// range. The "else" logic below is to quiet down run-time error.
|
||||
if(divisor >= 1 && divisor <= INT32_MAX)
|
||||
if(divisor >= 1 && divisor <= ck::NumericLimits<int32_t>::Max())
|
||||
{
|
||||
uint32_t shift = 0;
|
||||
for(shift = 0; shift < 32; ++shift)
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace math {
|
||||
extern "C" __device__ float __ocml_native_recip_f32(float);
|
||||
#endif
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
// math functions for the host, some are implemented by calling C++ std functions
|
||||
|
||||
static inline __host__ float abs(float x) { return std::abs(x); };
|
||||
@@ -457,6 +458,7 @@ inline __host__ double expm1<double>(double x)
|
||||
{
|
||||
return std::expm1(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
|
||||
|
||||
@@ -920,5 +922,23 @@ inline __device__ double expm1<double>(double x)
|
||||
return expm1(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T cos(T x)
|
||||
{
|
||||
return ck::type_convert<T>(cosf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float cos<float>(float x)
|
||||
{
|
||||
return cosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double cos<double>(double x)
|
||||
{
|
||||
return cos(x);
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace ck {
|
||||
|
||||
// Pseudo random number generator
|
||||
// version for fp32
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
|
||||
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<float, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
@@ -23,7 +23,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
|
||||
}
|
||||
|
||||
// version for fp16
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
|
||||
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<half_t, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
@@ -40,12 +40,18 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T,
|
||||
uint32_t seed_t,
|
||||
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
|
||||
ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
#ifdef __HIPCC_RTC__
|
||||
static_cast<void>(id);
|
||||
static_cast<void>(val);
|
||||
static_cast<void>(seed);
|
||||
#else
|
||||
std::ignore = id;
|
||||
std::ignore = val;
|
||||
std::ignore = seed;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/integral_constant.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <ck::index_t... Is>
|
||||
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
|
||||
{
|
||||
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
|
||||
os << S::At(S::Size() - ck::Number<1>{}).value << "}";
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -32,7 +32,7 @@ struct TupleElementKeyData
|
||||
template <typename T,
|
||||
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
|
||||
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
|
||||
{
|
||||
return std::forward(x.mData);
|
||||
return ck::forward(x.mData);
|
||||
}
|
||||
|
||||
template <typename Indices, typename... Xs>
|
||||
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
|
||||
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Y&& y)
|
||||
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
||||
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
||||
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
||||
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
|
||||
"wrong! inconsistent size");
|
||||
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
template <typename Y,
|
||||
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
||||
__host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
|
||||
false>
|
||||
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
||||
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/tie
|
||||
|
||||
@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
|
||||
const Tuple<Y&...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
[&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
[&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
@@ -157,6 +157,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
@@ -165,6 +166,7 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
|
||||
{
|
||||
return (is_detected<is_tuple, Ts>::value || ...);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t depth = 0, typename T>
|
||||
__host__ __device__ constexpr auto TupleDepth(const T&)
|
||||
|
||||
@@ -8,6 +8,158 @@
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
#ifdef __HIPCC_RTC__
|
||||
template <bool B>
|
||||
using bool_constant = integral_constant<bool, B>;
|
||||
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define CK_BUILTIN_TYPE_TRAIT1(name) \
|
||||
template <class T> \
|
||||
struct name : bool_constant<__##name(T)> \
|
||||
{ \
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define CK_BUILTIN_TYPE_TRAIT2(name) \
|
||||
template <class T, class U> \
|
||||
struct name : bool_constant<__##name(T, U)> \
|
||||
{ \
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define CK_BUILTIN_TYPE_TRAITN(name) \
|
||||
template <class... Ts> \
|
||||
struct name : bool_constant<__##name(Ts...)> \
|
||||
{ \
|
||||
}
|
||||
|
||||
CK_BUILTIN_TYPE_TRAIT1(is_class);
|
||||
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
|
||||
CK_BUILTIN_TYPE_TRAIT1(is_reference);
|
||||
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
|
||||
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
|
||||
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
|
||||
|
||||
template <class T>
|
||||
struct remove_cv
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct remove_cv<const T> : remove_cv<T>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct remove_cv<volatile T> : remove_cv<T>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct remove_reference
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_reference<T&>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_reference<T&&>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct remove_pointer
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_pointer<T*>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_pointer<T* const>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_pointer<T* volatile>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
template <class T>
|
||||
struct remove_pointer<T* const volatile>
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
|
||||
{
|
||||
return static_cast<T&&>(t_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
|
||||
{
|
||||
return static_cast<T&&>(t_);
|
||||
}
|
||||
|
||||
// TODO
|
||||
template<class T> struct is_const : false_type {};
|
||||
template<class T> struct is_const<const T> : true_type {};
|
||||
template< class T >
|
||||
inline constexpr bool is_const_v = is_const<T>::value;
|
||||
|
||||
template< class T >
|
||||
inline constexpr bool is_reference_v = is_reference<T>::value;
|
||||
|
||||
template<class T> struct remove_const { typedef T type; };
|
||||
template<class T> struct remove_const<const T> { typedef T type; };
|
||||
template< class T >
|
||||
using remove_const_t = typename remove_const<T>::type;
|
||||
|
||||
template< class T >
|
||||
inline constexpr bool is_class_v = is_class<T>::value;
|
||||
|
||||
template< class T >
|
||||
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
|
||||
|
||||
template< class... >
|
||||
using void_t = void;
|
||||
|
||||
using __hip::declval;
|
||||
#else
|
||||
#include <utility>
|
||||
#include <type_traits>
|
||||
using std::forward;
|
||||
using std::is_base_of;
|
||||
using std::is_class;
|
||||
using std::is_pointer;
|
||||
using std::is_reference;
|
||||
using std::is_trivially_copyable;
|
||||
using std::is_unsigned;
|
||||
using std::remove_cv;
|
||||
using std::remove_pointer;
|
||||
using std::remove_reference;
|
||||
using std::is_const_v;
|
||||
using std::is_reference_v;
|
||||
using std::remove_const_t;
|
||||
using std::is_class_v;
|
||||
using std::is_trivially_copyable_v;
|
||||
using std::void_t;
|
||||
using std::false_type;
|
||||
using std::true_type;
|
||||
using std::declval;
|
||||
#endif
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct is_same : public integral_constant<bool, false>
|
||||
@@ -23,19 +175,19 @@ template <typename X, typename Y>
|
||||
inline constexpr bool is_same_v = is_same<X, Y>::value;
|
||||
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
using remove_reference_t = typename remove_reference<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
using remove_cv_t = typename remove_cv<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
using remove_pointer_t = typename std::remove_pointer<T>::type;
|
||||
using remove_pointer_t = typename remove_pointer<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||
inline constexpr bool is_pointer_v = is_pointer<T>::value;
|
||||
|
||||
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y bit_cast(const X& x)
|
||||
|
||||
@@ -17,10 +17,10 @@ namespace ck {
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
|
||||
|
||||
using NonConstY = std::remove_const_t<Y>;
|
||||
using NonConstX = std::remove_const_t<X>;
|
||||
using NonConstY = ck::remove_const_t<Y>;
|
||||
using NonConstX = ck::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert_sp(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
@@ -166,7 +166,7 @@ template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
@@ -218,7 +218,7 @@ template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
@@ -501,6 +501,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <typename Y, typename X, std::size_t NumElems>
|
||||
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
|
||||
const std::array<X, NumElems>& x)
|
||||
@@ -510,6 +511,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
|
||||
y[i] = type_convert<Y>(x[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename Y, typename X, index_t NumElems>
|
||||
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
|
||||
|
||||
Reference in New Issue
Block a user