mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
now can build
This commit is contained in:
@@ -336,8 +336,8 @@ struct buffer_store<2>
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
static_assert(sizeof(T) == 2);
|
||||
using mbuf_t = short;
|
||||
asm volatile(
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
@@ -468,9 +468,9 @@ struct buffer_store_if<2>
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
static_assert(sizeof(T) == 2);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = short;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
@@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
|
||||
}
|
||||
|
||||
// buffer load i8
|
||||
CK_TILE_DEVICE int8_t
|
||||
CK_TILE_DEVICE_EXTERN int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
|
||||
|
||||
CK_TILE_DEVICE int8x2_t
|
||||
CK_TILE_DEVICE_EXTERN int8x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
|
||||
|
||||
CK_TILE_DEVICE int8x4_t
|
||||
CK_TILE_DEVICE_EXTERN int8x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
// buffer load i16
|
||||
CK_TILE_DEVICE int16_t
|
||||
CK_TILE_DEVICE_EXTERN int16_t
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
|
||||
|
||||
CK_TILE_DEVICE int16x2_t
|
||||
CK_TILE_DEVICE_EXTERN int16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
|
||||
|
||||
CK_TILE_DEVICE int16x4_t
|
||||
CK_TILE_DEVICE_EXTERN int16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
|
||||
|
||||
// buffer load i32
|
||||
CK_TILE_DEVICE int32_t
|
||||
CK_TILE_DEVICE_EXTERN int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
|
||||
CK_TILE_DEVICE int32x2_t
|
||||
CK_TILE_DEVICE_EXTERN int32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
CK_TILE_DEVICE int32x4_t
|
||||
CK_TILE_DEVICE_EXTERN int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
// buffer load fp16
|
||||
CK_TILE_DEVICE fp16_t
|
||||
CK_TILE_DEVICE_EXTERN _Float16
|
||||
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
|
||||
|
||||
CK_TILE_DEVICE fp16x4_t
|
||||
CK_TILE_DEVICE_EXTERN fp16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
|
||||
|
||||
// buffer load fp32
|
||||
CK_TILE_DEVICE float
|
||||
CK_TILE_DEVICE_EXTERN float
|
||||
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
|
||||
|
||||
CK_TILE_DEVICE fp32x2_t
|
||||
CK_TILE_DEVICE_EXTERN fp32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
|
||||
|
||||
CK_TILE_DEVICE fp32x4_t
|
||||
CK_TILE_DEVICE_EXTERN fp32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
|
||||
|
||||
// buffer store i8
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
// buffer store i16
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
// buffer store i32
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// buffer store fp16
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
|
||||
|
||||
// buffer store fp32
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
|
||||
// buffer atomic-add fp16
|
||||
CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer atomic-max fp64
|
||||
CK_TILE_DEVICE double
|
||||
CK_TILE_DEVICE_EXTERN double
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
@@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<fp16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<bf16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<bf16x2_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<bf16x4_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(
|
||||
src_thread_data.template get_as<bf16x4_t>()[number<0>{}],
|
||||
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(
|
||||
src_thread_data.template get_as<bf16x4_t>()[number<1>{}],
|
||||
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(bf16_t),
|
||||
@@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
|
||||
@@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop()
|
||||
{
|
||||
#if 1
|
||||
asm volatile("\
|
||||
s_nop 0 \n \
|
||||
" ::);
|
||||
#else
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_up(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_down(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,13 +9,15 @@
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST __host__
|
||||
#define CK_TILE_DEVICE __device__
|
||||
#define CK_TILE_HOST_DEVICE __host__ __device__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
@@ -122,7 +124,7 @@
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
@@ -132,3 +134,7 @@
|
||||
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
@@ -21,7 +21,12 @@ struct array
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
// TODO: do we need this?
|
||||
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
|
||||
// union {
|
||||
value_type data[N];
|
||||
// bulk_type __content;
|
||||
//};
|
||||
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
|
||||
// TODO: will initialize the data[] with the last value repeatedly
|
||||
// behavior different from std
|
||||
@@ -44,18 +49,24 @@ struct array
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
|
||||
template <typename Y>
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
|
||||
{
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = c;
|
||||
}
|
||||
template <typename ArrayType>
|
||||
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
|
||||
{
|
||||
static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = o.data[i];
|
||||
data[i] = static_cast<value_type>(c);
|
||||
}
|
||||
// template <typename Y>
|
||||
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// }
|
||||
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// return *this;
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
@@ -147,10 +158,10 @@ struct vector_traits<array<T, N>>
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
|
||||
{
|
||||
using value_type = remove_cvref_t<T>;
|
||||
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
|
||||
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
|
||||
}
|
||||
|
||||
// make empty array
|
||||
|
||||
@@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
// constexpr index_t can't be captured "-Wunused-lambda-capture"
|
||||
// TODO: this is ugly
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
@@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
}()
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
result = f(result, Seq::get(i));
|
||||
result = f(result, Seq::at(i));
|
||||
}
|
||||
|
||||
return result;
|
||||
@@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
flag = flag || f(Seq::get(i));
|
||||
flag = flag || f(Seq::at(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
@@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
flag = flag && f(Seq::get(i));
|
||||
flag = flag && f(Seq::at(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
@@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
// template <index_t... Is>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple<number<Is>...>)
|
||||
// {
|
||||
// return sequence<Is...>{};
|
||||
// }
|
||||
template <class... T>
|
||||
struct tuple;
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
|
||||
{
|
||||
return sequence<Is...>{};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
|
||||
|
||||
@@ -139,6 +139,26 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
}
|
||||
});
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
|
||||
@@ -237,21 +257,21 @@ template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
@@ -490,58 +510,58 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
} // namespace std
|
||||
|
||||
#if 1
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
|
||||
#else
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile:: \
|
||||
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
@@ -4,44 +4,36 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator==(const type_& x, const type_& y) \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator!=(const type_& x, const type_& y) \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<(const type_& x, const type_& y) \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<=(const type_& x, const type_& y) \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>(const type_& x, const type_& y) \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>=(const type_& x, const type_& y) \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator+(const type_& x, const type_& y) \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x) \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
@@ -49,66 +41,55 @@
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x, const type_& y) \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator*(const type_& x, const type_& y) \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator/(const type_& x, const type_& y) \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator+=(type_& x, const type_& y) \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator-=(type_& x, const type_& y) \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator*=(type_& x, const type_& y) \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator/=(type_& x, const type_& y) \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator++(type_& x) \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator--(type_& x) \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator++(type_& x, int) \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator--(type_& x, int) \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
|
||||
@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
|
||||
{
|
||||
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
return __builtin_bit_cast(
|
||||
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
return __builtin_bit_cast(Y,
|
||||
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
|
||||
(out_exponent << out_mant) | mantissa));
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
|
||||
|
||||
// prepare the codes
|
||||
constexpr X nan_code = 0x80;
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
|
||||
|
||||
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
if(x_raw == 0)
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
uint32_t sign = x_raw >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
|
||||
int exponent = (x_raw & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x;
|
||||
retval = x_raw;
|
||||
retval <<= 8;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
// HIP use fp16_hip_t as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
@@ -46,6 +53,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
@@ -61,6 +72,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const
|
||||
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t double_to_fp16(const double& x) { return half_t{x}; }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
|
||||
};
|
||||
|
||||
// arithmetic
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
@@ -14,8 +14,9 @@ struct constant
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <typename T, T v>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 1 ~ 0x7fffffff
|
||||
// __builtin_clz will produce unexpected result if x is 0;
|
||||
return 31 - clz(x);
|
||||
return 31 - __builtin_clz(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
|
||||
@@ -275,7 +276,7 @@ struct log2e<float>
|
||||
};
|
||||
|
||||
template <typename T = double>
|
||||
inline constexpr T log2e_v = log2e<T>::value;
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -298,16 +299,32 @@ bool isnan(const float& x)
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __expf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp(float x) { return std::expf(x); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp2(float x) { return std::exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
|
||||
@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bfp16
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
|
||||
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
|
||||
@@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
using InVecType = typename InVec::type;
|
||||
using OutVecType = typename OutVec::type;
|
||||
// using InVec = typename InVec::type;
|
||||
// using OutVec = typename OutVec::type;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
@@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVecType>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
});
|
||||
|
||||
// transpose
|
||||
@@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVecType>(
|
||||
number<out_offset / sizeof(OutVecType)>{},
|
||||
out_vectors[i].template get_as<OutVecType>()[I0]);
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>{}); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
|
||||
@@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
|
||||
{
|
||||
return Base::get_top_dimension_length();
|
||||
return Base::get_top_dimension_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
|
||||
|
||||
@@ -296,7 +296,8 @@ CK_TILE_HOST_DEVICE constexpr auto
|
||||
&rh_major_minor_to_hidden_ids,
|
||||
&rh_major_minor_to_hidden_lengths](auto idim_x) {
|
||||
// typename HsLengthss::base{}.foo();
|
||||
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
constexpr auto h_minor_lengths =
|
||||
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
|
||||
|
||||
constexpr index_t ndim_h_minor = h_minor_lengths.size();
|
||||
@@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
using old_scan =
|
||||
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
@@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t _split_idx =
|
||||
std::conditional_t<_split_flag, number<id>, number<0>>::value;
|
||||
@@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t split_idx =
|
||||
std::conditional_t<split_flag, number<id>, number<0>>::value;
|
||||
@@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask,
|
||||
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
|
||||
SliceSize>;
|
||||
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
|
||||
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
|
||||
"can not evenly divide this sequence, please check");
|
||||
return make_tuple(typename sliced_type::dim_lengths{},
|
||||
typename sliced_type::dim_slices{},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
|
||||
#if 1
|
||||
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
#else
|
||||
|
||||
@@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
template <typename T>
|
||||
using remove_pointer_t = typename std::remove_pointer<T>::type;
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = std::is_arithmetic<T>::v ? false : T::is_static();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
@@ -69,6 +48,36 @@ struct nonesuch
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
using has_is_static = decltype(T::is_static());
|
||||
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_is_static, T>{})
|
||||
return T::is_static();
|
||||
else
|
||||
return std::is_arithmetic<T>::value;
|
||||
}();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
// FIXME: do we need this anymore?
|
||||
template <
|
||||
typename PY,
|
||||
|
||||
Reference in New Issue
Block a user