now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -336,8 +336,8 @@ struct buffer_store<2>
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
@@ -468,9 +468,9 @@ struct buffer_store_if<2>
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
@@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
}
// buffer load i8
CK_TILE_DEVICE int8_t
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE int8x2_t
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE int8x4_t
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE int16_t
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE int16x2_t
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE int16x4_t
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE int32_t
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE int32x2_t
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE int32x4_t
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE fp16_t
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE fp16x2_t
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE fp16x4_t
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE float
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE fp32x2_t
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE fp32x4_t
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
@@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE double
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
@@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<fp16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<bf16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<bf16x2_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<bf16x4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<0>{}],
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<1>{}],
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bf16_t),
@@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
}
// Direct loads from global to LDS.
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,

View File

@@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
CK_TILE_DEVICE void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck_tile

View File

@@ -9,6 +9,9 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {
@@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_up(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_down(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
} // namespace ck_tile

View File

@@ -9,13 +9,15 @@
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST __host__
#define CK_TILE_DEVICE __device__
#define CK_TILE_HOST_DEVICE __host__ __device__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
@@ -122,7 +124,7 @@
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
@@ -132,3 +134,7 @@
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif

View File

@@ -21,7 +21,12 @@ struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// TODO: will initialize the data[] with the last value repeatedly
// behavior different from std
@@ -44,18 +49,24 @@ struct array
data[i] = vlast;
}
}
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
template <typename Y>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = c;
}
template <typename ArrayType>
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
{
static_assert(ArrayType::size() == size(), "wrong! size not the same");
for(auto i = 0; i < size(); i++)
data[i] = o.data[i];
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
@@ -147,10 +158,10 @@ struct vector_traits<array<T, N>>
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
{
using value_type = remove_cvref_t<T>;
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
}
// make empty array

View File

@@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
@@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
}()
#endif
} // namespace ck_tile

View File

@@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
for(index_t i = 0; i < Seq::size(); ++i)
{
result = f(result, Seq::get(i));
result = f(result, Seq::at(i));
}
return result;
@@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag || f(Seq::get(i));
flag = flag || f(Seq::at(i));
}
return flag;
@@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag && f(Seq::get(i));
flag = flag && f(Seq::at(i));
}
return flag;
@@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// template <index_t... Is>
// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple<number<Is>...>)
// {
// return sequence<Is...>{};
// }
template <class... T>
struct tuple;
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
return sequence<Is...>{};
}
namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>

View File

@@ -139,6 +139,26 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
return !(a == b);
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
@@ -237,21 +257,21 @@ template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten
@@ -490,58 +510,58 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) \
_Pragma("clang diagnostic pop")
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif

View File

@@ -4,44 +4,36 @@
#pragma once
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
CK_TILE_HOST_DEVICE \
bool operator==(const type_& x, const type_& y) \
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator!=(const type_& x, const type_& y) \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<(const type_& x, const type_& y) \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<=(const type_& x, const type_& y) \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>(const type_& x, const type_& y) \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>=(const type_& x, const type_& y) \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
type_ operator+(const type_& x, const type_& y) \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x) \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
@@ -49,66 +41,55 @@
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x, const type_& y) \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator*(const type_& x, const type_& y) \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator/(const type_& x, const type_& y) \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_& operator+=(type_& x, const type_& y) \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator-=(type_& x, const type_& y) \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator*=(type_& x, const type_& y) \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator/=(type_& x, const type_& y) \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator++(type_& x) \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator--(type_& x) \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_ operator++(type_& x, int) \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator--(type_& x, int) \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \

View File

@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x);
// HIP use __hip_bfloat16 as struct
struct alignas(2) bfloat16_t
{
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
return float_to_bf16_truc_raw(f);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
return u.fp32;
}
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
{
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
}
};
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
// math
CK_TILE_HOST_DEVICE

View File

@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
// convert to bitwise
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
// check if x is 0.0
if(x_bitwise == 0)
return 0;
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
}
else
{
return signed_inf;
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
return __builtin_bit_cast(
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
}
template <typename X, typename Y, bool negative_zero_nan>
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr X nan_code = 0x80;
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
if(x_raw == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
uint32_t sign = x_raw >> (in_exp + in_mant);
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
int exponent = (x_raw & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if constexpr(negative_zero_nan)
{
if(x == nan_code)
if(x_raw == nan_code)
return NaN;
}
else
{
if(x == nan_code)
if(x_raw == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval = x_raw;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
// math
CK_TILE_HOST_DEVICE

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include <hip/hip_fp16.h>
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x);
// HIP use fp16_hip_t as interchangable data type for float16
struct alignas(2) half_t
{
@@ -46,6 +53,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
@@ -61,6 +72,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x)
{
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
CK_TILE_HOST_DEVICE
half_t double_to_fp16(const double& x) { return half_t{x}; }
// limits
template <class T>
struct numeric_limits;
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
};
// arithmetic
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
return y;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
return y;
}
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }

View File

@@ -14,8 +14,9 @@ struct constant
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <typename T, T v>

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace ck_tile {
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x need to be 1 ~ 0x7fffffff
// __builtin_clz will produce unexpected result if x is 0;
return 31 - clz(x);
return 31 - __builtin_clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
@@ -275,7 +276,7 @@ struct log2e<float>
};
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
@@ -298,16 +299,32 @@ bool isnan(const float& x)
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
} // namespace ck_tile

View File

@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)

View File

@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bfp16
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));

View File

@@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
using InVecType = typename InVec::type;
using OutVecType = typename OutVec::type;
// using InVec = typename InVec::type;
// using OutVec = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
@@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVecType>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVec>(
number<in_offset / vec_length_in>{});
});
// transpose
@@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVecType>(
number<out_offset / sizeof(OutVecType)>{},
out_vectors[i].template get_as<OutVecType>()[I0]);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}

View File

@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \

View File

@@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
{
return Base::get_top_dimension_length();
return Base::get_top_dimension_lengths();
}
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const

View File

@@ -296,7 +296,8 @@ CK_TILE_HOST_DEVICE constexpr auto
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
// typename HsLengthss::base{}.foo();
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
constexpr auto h_minor_lengths =
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();
@@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
@@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
@@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
@@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"

View File

@@ -7,14 +7,14 @@
#if 1
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else

View File

@@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
namespace impl {
template <typename T>
struct is_static_impl
{
static constexpr bool value = std::is_arithmetic<T>::v ? false : T::is_static();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
@@ -69,6 +48,36 @@ struct nonesuch
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
namespace impl {
template <typename T>
using has_is_static = decltype(T::is_static());
template <typename T>
struct is_static_impl
{
static constexpr bool value = []() {
if constexpr(is_detected<has_is_static, T>{})
return T::is_static();
else
return std::is_arithmetic<T>::value;
}();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
template <
typename PY,

View File

@@ -40,7 +40,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -98,7 +98,7 @@ template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -157,7 +157,7 @@ template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -182,7 +182,7 @@ check_err(const Range& out,
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
@@ -220,11 +220,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
@@ -270,12 +270,12 @@ template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
@@ -323,12 +323,12 @@ template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{

View File

@@ -3,13 +3,14 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <sstream>
#include <stdexcept>
#include <hip/hip_runtime.h>
namespace ck_tile {
// To be removed, which really does not tell the location of failed HIP functional call
inline void hip_check_error(hipError_t x)
CK_TILE_HOST void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
{

View File

@@ -18,11 +18,11 @@
namespace ck_tile {
template <typename Range>
std::ostream& LogRange(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
@@ -37,11 +37,11 @@ std::ostream& LogRange(std::ostream& os,
}
template <typename T, typename Range>
std::ostream& LogRangeAsType(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
@@ -56,13 +56,13 @@ std::ostream& LogRangeAsType(std::ostream& os,
}
template <typename F, typename T, std::size_t... Is>
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
return f(std::get<Is>(args)...);
}
template <typename F, typename T>
auto call_f_unpack_args(F f, T args)
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
{
constexpr std::size_t N = std::tuple_size<T>{};
@@ -70,13 +70,13 @@ auto call_f_unpack_args(F f, T args)
}
template <typename F, typename T, std::size_t... Is>
auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
return F(std::get<Is>(args)...);
}
template <typename F, typename T>
auto construct_f_unpack_args(F, T args)
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
{
constexpr std::size_t N = std::tuple_size<T>{};
@@ -87,7 +87,19 @@ struct HostTensorDescriptor
{
HostTensorDescriptor() = default;
void CalculateStrides();
void CalculateStrides()
{
mStrides.clear();
mStrides.resize(mLens.size(), 0);
if(mStrides.empty())
return;
mStrides.back() = 1;
std::partial_sum(mLens.rbegin(),
mLens.rend() - 1,
mStrides.rbegin() + 1,
std::multiplies<std::size_t>());
}
template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
@@ -123,12 +135,28 @@ struct HostTensorDescriptor
{
}
std::size_t get_num_of_dimension() const;
std::size_t get_element_size() const;
std::size_t get_element_space_size() const;
std::size_t get_num_of_dimension() const { return mLens.size(); }
std::size_t get_element_size() const
{
assert(mLens.size() == mStrides.size());
return std::accumulate(
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t get_element_space_size() const
{
std::size_t space = 1;
for(std::size_t i = 0; i < mLens.size(); ++i)
{
if(mLens[i] == 0)
continue;
const std::vector<std::size_t>& get_lengths() const;
const std::vector<std::size_t>& GetStrides() const;
space += (mLens[i] - 1) * mStrides[i];
}
return space;
}
const std::vector<std::size_t>& get_lengths() const { return mLens; }
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
@@ -151,8 +179,8 @@ struct HostTensorDescriptor
};
template <typename New2Old>
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
const New2Old& new2old)
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(
const HostTensorDescriptor& a, const New2Old& new2old)
{
std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
std::vector<std::size_t> new_strides(a.get_num_of_dimension());
@@ -238,7 +266,7 @@ struct ParallelTensorFunctor
};
template <typename F, typename... Xs>
auto make_ParallelTensorFunctor(F f, Xs... xs)
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
{
return ParallelTensorFunctor<F, Xs...>(f, xs...);
}

View File

@@ -20,12 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
}
template <typename... Args, typename F>
float launch_and_time_kernel(const stream_config& s,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
@@ -75,13 +75,13 @@ float launch_and_time_kernel(const stream_config& s,
}
template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const stream_config& s,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
@@ -151,12 +151,12 @@ template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
CK_TILE_HOST float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
{
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;

View File

@@ -10,7 +10,6 @@
// ranges implementation are not intented to be used by user
// TODO: do we need this?
namespace ck_tile {
namespace ranges {
template <typename T>
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;
@@ -21,8 +20,7 @@ using iter_reference_t = decltype(*std::declval<T&>());
template <typename T>
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;
//.........................
namespace ranges {
template <typename R>
using iterator_t = decltype(std::begin(std::declval<R&>()));

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename BinaryElementOp = ck_tile::plus<AccDataType>>
void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
{
const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename CDataType, typename MaskingType>
void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
const int M = c_b_m_n.mDesc.get_lengths()[1];
const int N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename CompDataType, typename BDataType>
void reference_batched_softmax(
CK_TILE_HOST void reference_batched_softmax(
const HostTensor<ADataType>& a_b_m_n,
HostTensor<BDataType>& b_b_m_n,
std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];

View File

@@ -10,25 +10,25 @@
namespace ck_tile {
template <typename T>
void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
{
int GemmM = in_mtx_host_ref.get_lengths()[0];
int GemmK = in_mtx_host_ref.get_lengths()[1];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];

View File

@@ -10,12 +10,13 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m_n)
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
AccDataType v_max = ck_tile::NumericLimits<ADataType>::Lowest();
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)

View File

@@ -575,9 +575,8 @@ struct FmhaFwdKernel
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// following copy capture of the 'i_nhead'
/// if compiled in C++20
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -189,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -208,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -346,12 +347,15 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -360,7 +364,7 @@ struct BlockFmhaPipelineQRKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -375,7 +379,7 @@ struct BlockFmhaPipelineQRKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -231,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
@@ -251,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -389,12 +390,15 @@ struct BlockFmhaPipelineQRKSVSAsync
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -403,7 +407,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -454,7 +458,7 @@ struct BlockFmhaPipelineQRKSVSAsync
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -181,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -329,12 +330,15 @@ struct BlockFmhaPipelineQRKSVSFp8
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -343,7 +347,7 @@ struct BlockFmhaPipelineQRKSVSFp8
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -358,7 +362,7 @@ struct BlockFmhaPipelineQRKSVSFp8
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -338,12 +338,15 @@ struct BlockFmhaPipelineQSKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -352,7 +355,7 @@ struct BlockFmhaPipelineQSKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -367,7 +370,7 @@ struct BlockFmhaPipelineQSKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -9,6 +9,11 @@
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
@@ -97,9 +102,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -222,9 +226,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -918,12 +921,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8)
{
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {