now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -336,8 +336,8 @@ struct buffer_store<2>
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
@@ -468,9 +468,9 @@ struct buffer_store_if<2>
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
@@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
}
// buffer load i8
CK_TILE_DEVICE int8_t
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE int8x2_t
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE int8x4_t
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE int16_t
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE int16x2_t
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE int16x4_t
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE int32_t
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE int32x2_t
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE int32x4_t
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE fp16_t
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE fp16x2_t
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE fp16x4_t
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE float
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE fp32x2_t
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE fp32x4_t
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
@@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE double
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
@@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<fp16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<bf16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<bf16x2_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<bf16x4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<0>{}],
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<1>{}],
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bf16_t),
@@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
}
// Direct loads from global to LDS.
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,

View File

@@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
CK_TILE_DEVICE void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck_tile

View File

@@ -9,6 +9,9 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {
@@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_up(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_down(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
} // namespace ck_tile

View File

@@ -9,13 +9,15 @@
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST __host__
#define CK_TILE_DEVICE __device__
#define CK_TILE_HOST_DEVICE __host__ __device__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
@@ -122,7 +124,7 @@
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
@@ -132,3 +134,7 @@
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif

View File

@@ -21,7 +21,12 @@ struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// TODO: will initialize the data[] with the last value repeatedly
// behavior different from std
@@ -44,18 +49,24 @@ struct array
data[i] = vlast;
}
}
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
template <typename Y>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = c;
}
template <typename ArrayType>
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
{
static_assert(ArrayType::size() == size(), "wrong! size not the same");
for(auto i = 0; i < size(); i++)
data[i] = o.data[i];
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
@@ -147,10 +158,10 @@ struct vector_traits<array<T, N>>
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
{
using value_type = remove_cvref_t<T>;
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
}
// make empty array

View File

@@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
@@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
}()
#endif
} // namespace ck_tile

View File

@@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
for(index_t i = 0; i < Seq::size(); ++i)
{
result = f(result, Seq::get(i));
result = f(result, Seq::at(i));
}
return result;
@@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag || f(Seq::get(i));
flag = flag || f(Seq::at(i));
}
return flag;
@@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag && f(Seq::get(i));
flag = flag && f(Seq::at(i));
}
return flag;
@@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// template <index_t... Is>
// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple<number<Is>...>)
// {
// return sequence<Is...>{};
// }
template <class... T>
struct tuple;
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
return sequence<Is...>{};
}
namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>

View File

@@ -139,6 +139,26 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
return !(a == b);
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
@@ -237,21 +257,21 @@ template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten
@@ -490,58 +510,58 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) \
_Pragma("clang diagnostic pop")
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif

View File

@@ -4,44 +4,36 @@
#pragma once
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
CK_TILE_HOST_DEVICE \
bool operator==(const type_& x, const type_& y) \
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator!=(const type_& x, const type_& y) \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<(const type_& x, const type_& y) \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<=(const type_& x, const type_& y) \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>(const type_& x, const type_& y) \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>=(const type_& x, const type_& y) \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
type_ operator+(const type_& x, const type_& y) \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x) \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
@@ -49,66 +41,55 @@
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x, const type_& y) \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator*(const type_& x, const type_& y) \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator/(const type_& x, const type_& y) \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_& operator+=(type_& x, const type_& y) \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator-=(type_& x, const type_& y) \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator*=(type_& x, const type_& y) \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator/=(type_& x, const type_& y) \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator++(type_& x) \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator--(type_& x) \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_ operator++(type_& x, int) \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator--(type_& x, int) \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \

View File

@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x);
// HIP use __hip_bfloat16 as struct
struct alignas(2) bfloat16_t
{
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
return float_to_bf16_truc_raw(f);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
return u.fp32;
}
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
{
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
}
};
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
// math
CK_TILE_HOST_DEVICE

View File

@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
// convert to bitwise
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
// check if x is 0.0
if(x_bitwise == 0)
return 0;
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
}
else
{
return signed_inf;
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
return __builtin_bit_cast(
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
}
template <typename X, typename Y, bool negative_zero_nan>
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr X nan_code = 0x80;
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
if(x_raw == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
uint32_t sign = x_raw >> (in_exp + in_mant);
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
int exponent = (x_raw & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if constexpr(negative_zero_nan)
{
if(x == nan_code)
if(x_raw == nan_code)
return NaN;
}
else
{
if(x == nan_code)
if(x_raw == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval = x_raw;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
// math
CK_TILE_HOST_DEVICE

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include <hip/hip_fp16.h>
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x);
// HIP use fp16_hip_t as interchangable data type for float16
struct alignas(2) half_t
{
@@ -46,6 +53,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
@@ -61,6 +72,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x)
{
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
CK_TILE_HOST_DEVICE
half_t double_to_fp16(const double& x) { return half_t{x}; }
// limits
template <class T>
struct numeric_limits;
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
};
// arithmetic
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
return y;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
return y;
}
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }

View File

@@ -14,8 +14,9 @@ struct constant
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <typename T, T v>

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace ck_tile {
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x need to be 1 ~ 0x7fffffff
// __builtin_clz will produce unexpected result if x is 0;
return 31 - clz(x);
return 31 - __builtin_clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
@@ -275,7 +276,7 @@ struct log2e<float>
};
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
@@ -298,16 +299,32 @@ bool isnan(const float& x)
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
} // namespace ck_tile

View File

@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)

View File

@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bfp16
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));

View File

@@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
using InVecType = typename InVec::type;
using OutVecType = typename OutVec::type;
// using InVec = typename InVec::type;
// using OutVec = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
@@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVecType>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVec>(
number<in_offset / vec_length_in>{});
});
// transpose
@@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVecType>(
number<out_offset / sizeof(OutVecType)>{},
out_vectors[i].template get_as<OutVecType>()[I0]);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}

View File

@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \

View File

@@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
{
return Base::get_top_dimension_length();
return Base::get_top_dimension_lengths();
}
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const

View File

@@ -296,7 +296,8 @@ CK_TILE_HOST_DEVICE constexpr auto
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
// typename HsLengthss::base{}.foo();
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
constexpr auto h_minor_lengths =
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();
@@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
@@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
@@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
@@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"

View File

@@ -7,14 +7,14 @@
#if 1
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else

View File

@@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
namespace impl {
template <typename T>
struct is_static_impl
{
static constexpr bool value = std::is_arithmetic<T>::v ? false : T::is_static();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
@@ -69,6 +48,36 @@ struct nonesuch
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
namespace impl {
template <typename T>
using has_is_static = decltype(T::is_static());
template <typename T>
struct is_static_impl
{
static constexpr bool value = []() {
if constexpr(is_detected<has_is_static, T>{})
return T::is_static();
else
return std::is_arithmetic<T>::value;
}();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
template <
typename PY,