mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
topk_softmax (#1592)
* topk_softmax
* remove some file
* fix atomix linear_offset
* address various comment, and change sfc get_index api to static(tuple)
[ROCm/composable_kernel commit: b098b71b05]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -81,8 +81,10 @@ struct space_filling_curve
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
|
||||
}
|
||||
|
||||
// Do not use this function directly!
|
||||
// TODO: can refactor into generic lambda in the future
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
|
||||
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
@@ -153,11 +155,11 @@ struct space_filling_curve
|
||||
return idx_md;
|
||||
}
|
||||
|
||||
// FIXME: rename this function
|
||||
// FIXME: return tuple of number<>, which is compile time only variable
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
|
||||
{
|
||||
constexpr auto idx = get_index(number<AccessIdx1d>{});
|
||||
constexpr auto idx = _get_index(number<AccessIdx1d>{});
|
||||
|
||||
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
|
||||
}
|
||||
|
||||
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
template<index_t N, typename T> struct smem_load_trait;
|
||||
|
||||
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
|
||||
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
|
||||
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
// NOTE: smem load/store does not need a pre_nop to enforce the dependency in software, happy :)
|
||||
// Shared-memory load functor, specialized on the number of bytes read per call.
template <index_t>
struct smem_load;

// 16-byte variant: issues one ds_read_b128.
template <>
struct smem_load<16>
{
    // value    : destination object, written directly by the instruction (must be 16 bytes)
    // v_offset : address operand, bound with the "v" register constraint
    // i_offset : offset bound with the "n" constraint, i.e. it has to fold to an immediate
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
        using payload_t = typename impl::smem_load_trait<16, T>::payload_t;
        auto& dst       = reinterpret_cast<payload_t&>(value);
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(dst) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
|
||||
|
||||
template <>
|
||||
struct smem_load<8>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
|
||||
asm volatile("ds_read_b64 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<4>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
|
||||
asm volatile("ds_read_b32 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<2>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u16 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<1>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u8 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl{
|
||||
|
||||
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_linear_addr_offset,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset,
|
||||
index_t dst_linear_addr_offset,
|
||||
index_t is_valid_element = 1)
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
is_valid_element);
|
||||
}
|
||||
else
|
||||
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
dst_linear_addr_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2014,6 +2156,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2041,16 +2186,19 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2066,6 +2214,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
@@ -2086,12 +2240,42 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
@@ -2146,6 +2330,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const index_t dst_linear_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
dst_thread_element_valid);
|
||||
}
|
||||
|
||||
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
|
||||
@@ -41,6 +41,19 @@
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
// Memory address-space attributes for pointer declarations.
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
// The attributes are applied only when compiling with hipcc, which predefines __HIPCC__
// (two trailing underscores); with any other compiler the macros expand to nothing.
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
@@ -205,3 +218,8 @@
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
// workaround: compiler not emitting the reciprocal instruction from __frcp_rn()
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
|
||||
#define CK_TILE_WORKAROUND_SWDEV_383542 1
|
||||
#endif
|
||||
|
||||
@@ -623,7 +623,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
@@ -635,7 +635,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
@@ -647,7 +647,7 @@ template <typename... Xs,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
|
||||
@@ -487,55 +487,12 @@ struct log2e<float>
|
||||
template <typename T = double>
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
float abs(const float& x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float f32;
|
||||
uint32_t u32;
|
||||
} y;
|
||||
y.f32 = x;
|
||||
y.u32 = y.u32 & 0x7fffffff;
|
||||
return y.f32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const float& x)
|
||||
{
|
||||
uint32_t xx = bit_cast<uint32_t>(x);
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __ocml_exp_f32(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp(float x) { return std::expf(x); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp2(float x) { return std::exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
|
||||
{
|
||||
return __builtin_amdgcn_sad_u16(x, y, acc);
|
||||
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
|
||||
return (x > y ? (x - y) : (y - x)) + acc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace ck_tile
|
||||
// the functions below need the data types to be defined first
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#ifndef __HIP_DEVICE_COMPILE__
|
||||
#include <cmath>
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
#if CK_TILE_WORKAROUND_SWDEV_383542
|
||||
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
|
||||
#endif
|
||||
|
||||
// math functions for the host, some are implemented by calling C++ std functions
|
||||
|
||||
// Host abs overloads. float/double defer to std::abs; the integer overloads use the
// branch-free sign-mask form; fp16 clears the sign bit of the raw bit pattern.
CK_TILE_HOST float abs(float x) { return std::abs(x); }

CK_TILE_HOST double abs(double x) { return std::abs(x); }

CK_TILE_HOST int8_t abs(int8_t x)
{
    // arithmetic shift by 7 yields the sign mask: 0 for x >= 0, -1 for x < 0
    const int8_t sign_mask = x >> (8 - 1);
    return (x ^ sign_mask) - sign_mask;
}

CK_TILE_HOST int32_t abs(int32_t x)
{
    // arithmetic shift by 31 yields the sign mask: 0 for x >= 0, -1 for x < 0
    const int32_t sign_mask = x >> (32 - 1);
    return (x ^ sign_mask) - sign_mask;
}

CK_TILE_HOST fp16_t abs(fp16_t x)
{
    // drop bit 15 (the sign bit) of the 16-bit pattern
    const uint16_t magnitude_bits = bit_cast<uint16_t>(x) & 0x7fff;
    return bit_cast<fp16_t>(magnitude_bits);
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST int4_t abs(int4_t x)
{
    const int4_t sign_mask = x >> (4 - 1);
    return (x ^ sign_mask) - sign_mask;
}
#endif
|
||||
|
||||
// Host isnan overloads. Integer types can never hold a NaN, so those overloads always
// return false; fp16 is tested on its raw bit pattern.
CK_TILE_HOST bool isnan(float x) { return std::isnan(x); }

CK_TILE_HOST bool isnan(double x) { return std::isnan(x); }

CK_TILE_HOST bool isnan([[maybe_unused]] int8_t x) { return false; }

CK_TILE_HOST bool isnan([[maybe_unused]] int32_t x) { return false; }

CK_TILE_HOST bool isnan(fp16_t x)
{
    // with the sign bit masked off, anything above 0x7C00 is reported as NaN
    const uint16_t raw_bits = bit_cast<uint16_t>(x);
    return (raw_bits & 0x7FFF) > 0x7C00;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan([[maybe_unused]] int4_t x) { return false; }
#endif
|
||||
|
||||
// Host sqrt overloads; fp16 is widened to float, evaluated, and narrowed back.
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
    const float x_f32 = static_cast<float>(x);
    return static_cast<fp16_t>(std::sqrt(x_f32));
}

CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }

CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T tanh(T x)
|
||||
{
|
||||
return type_convert<T>(std::tanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float tanh<float>(float x)
|
||||
{
|
||||
return std::tanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double tanh<double>(double x)
|
||||
{
|
||||
return std::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T acos(T x)
|
||||
{
|
||||
return type_convert<T>(std::acosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float acos<float>(float x)
|
||||
{
|
||||
return std::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double acos<double>(double x)
|
||||
{
|
||||
return std::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T neg(T x)
|
||||
{
|
||||
return type_convert<T>(-(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T atan(T x)
|
||||
{
|
||||
return type_convert<T>(std::atanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float atan<float>(float x)
|
||||
{
|
||||
return std::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double atan<double>(double x)
|
||||
{
|
||||
return std::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T sin(T x)
|
||||
{
|
||||
return type_convert<T>(std::sinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float sin<float>(float x)
|
||||
{
|
||||
return std::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double sin<double>(double x)
|
||||
{
|
||||
return std::sin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T asin(T x)
|
||||
{
|
||||
return type_convert<T>(std::asinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float asin<float>(float x)
|
||||
{
|
||||
return std::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double asin<double>(double x)
|
||||
{
|
||||
return std::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T asinh(T x)
|
||||
{
|
||||
return type_convert<T>(std::asinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float asinh<float>(float x)
|
||||
{
|
||||
return std::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double asinh<double>(double x)
|
||||
{
|
||||
return std::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T cos(T x)
|
||||
{
|
||||
return type_convert<T>(std::cosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float cos<float>(float x)
|
||||
{
|
||||
return std::cosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double cos<double>(double x)
|
||||
{
|
||||
return std::cos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T acosh(T x)
|
||||
{
|
||||
return type_convert<T>(std::acoshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float acosh<float>(float x)
|
||||
{
|
||||
return std::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double acosh<double>(double x)
|
||||
{
|
||||
return std::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T tan(T x)
|
||||
{
|
||||
return type_convert<T>(std::tanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float tan<float>(float x)
|
||||
{
|
||||
return std::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double tan<double>(double x)
|
||||
{
|
||||
return std::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T atanh(T x)
|
||||
{
|
||||
return type_convert<T>(std::atanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float atanh<float>(float x)
|
||||
{
|
||||
return std::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double atanh<double>(double x)
|
||||
{
|
||||
return std::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T sinh(T x)
|
||||
{
|
||||
return type_convert<T>(std::sinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float sinh<float>(float x)
|
||||
{
|
||||
return std::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double sinh<double>(double x)
|
||||
{
|
||||
return std::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T ceil(T x)
|
||||
{
|
||||
return type_convert<T>(std::ceilf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float ceil<float>(float x)
|
||||
{
|
||||
return std::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double ceil<double>(double x)
|
||||
{
|
||||
return std::ceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T cosh(T x)
|
||||
{
|
||||
return type_convert<T>(std::coshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float cosh<float>(float x)
|
||||
{
|
||||
return std::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double cosh<double>(double x)
|
||||
{
|
||||
return std::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T floor(T x)
|
||||
{
|
||||
return type_convert<T>(std::floorf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float floor<float>(float x)
|
||||
{
|
||||
return std::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double floor<double>(double x)
|
||||
{
|
||||
return std::floor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T rcp(T x)
|
||||
{
|
||||
return type_convert<T>(1.f / type_convert<float>(x));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T exp(T x)
|
||||
{
|
||||
return type_convert<T>(std::expf(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float exp<float>(float x)
|
||||
{
|
||||
return std::expf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double exp<double>(double x)
|
||||
{
|
||||
return std::exp(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T log(T x)
|
||||
{
|
||||
return type_convert<T>(std::logf(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float log<float>(float x)
|
||||
{
|
||||
return std::logf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double log<double>(double x)
|
||||
{
|
||||
return std::log(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T pow(T x, T gamma)
|
||||
{
|
||||
return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float pow<float>(float x, float gamma)
|
||||
{
|
||||
return std::powf(x, gamma);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double pow<double>(double x, double gamma)
|
||||
{
|
||||
return std::pow(x, gamma);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T expm1(T x)
|
||||
{
|
||||
return type_convert<T>(std::expm1f(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float expm1<float>(float x)
|
||||
{
|
||||
return std::expm1f(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double expm1<double>(double x)
|
||||
{
|
||||
return std::expm1(x);
|
||||
}
|
||||
|
||||
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
|
||||
|
||||
// abs(float): clears bit 31 (the sign bit) of the 32-bit representation.
// bit_cast replaces the previous union-based type punning: reading the
// inactive member of a union is undefined behavior in ISO C++, and bit_cast is
// already what the fp16_t overload below uses.
CK_TILE_DEVICE float abs(float x)
{
    const uint32_t bits = bit_cast<uint32_t>(x);

    return bit_cast<float>(bits & 0x7fffffffu);
}

// abs(double): forwards to the global-namespace ::abs.
CK_TILE_DEVICE double abs(double x) { return ::abs(x); }

// abs(int8_t): branch-free form. sgn is produced by shifting the sign bit
// across the value (0 for non-negative x, all ones for negative x with an
// arithmetic right shift), so (x ^ sgn) - sgn negates x only when it is
// negative.
CK_TILE_DEVICE int8_t abs(int8_t x)
{
    int8_t sgn = x >> (8 - 1);

    return (x ^ sgn) - sgn;
}

// abs(int32_t): same branch-free sign-mask technique as the int8_t overload.
CK_TILE_DEVICE int32_t abs(int32_t x)
{
    int32_t sgn = x >> (32 - 1);

    return (x ^ sgn) - sgn;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// abs(int4_t): same sign-mask technique; only compiled when the experimental
// int4 extension macro is defined.
CK_TILE_DEVICE int4_t abs(int4_t x)
{
    int4_t sgn = x >> (4 - 1);

    return (x ^ sgn) - sgn;
}
#endif

// abs(fp16_t): clears bit 15 (the sign bit) of the 16-bit representation.
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);

    uint16_t abs_xx = xx & 0x7fff;

    fp16_t abs_x = bit_cast<fp16_t>(abs_xx);

    return abs_x;
}
|
||||
|
||||
// isnan(float): forwards to the global-namespace ::isnan.
CK_TILE_DEVICE bool isnan(float x)
{
    return ::isnan(x);
}

// isnan(double): forwards to the global-namespace ::isnan.
CK_TILE_DEVICE bool isnan(double x)
{
    return ::isnan(x);
}

// isnan(int8_t): the argument is ignored; the answer is always false.
CK_TILE_DEVICE bool isnan(int8_t /*x*/)
{
    return false;
}

// isnan(int32_t): the argument is ignored; the answer is always false.
CK_TILE_DEVICE bool isnan(int32_t /*x*/)
{
    return false;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// isnan(int4_t): the argument is ignored; the answer is always false.
CK_TILE_DEVICE bool isnan(int4_t /*x*/)
{
    return false;
}
#endif

// isnan(fp16_t): inspects the raw 16 bits. After masking off the sign bit
// (0x7FFF), any pattern strictly above 0x7C00 is reported as NaN.
// NOTE(review): this matches the IEEE-754 binary16 layout (0x7C00 = infinity);
// assumes fp16_t uses that layout - confirm.
CK_TILE_DEVICE bool isnan(fp16_t x)
{
    const uint16_t magnitude = bit_cast<uint16_t>(x) & 0x7FFF;

    return magnitude > 0x7C00;
}
|
||||
|
||||
// sqrt(fp16_t): widens x to float, takes the root with
// __builtin_amdgcn_sqrtf, then narrows the result back to fp16_t.
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
    const float x_f32 = static_cast<float>(x);
    const float root  = __builtin_amdgcn_sqrtf(x_f32);
    return static_cast<fp16_t>(root);
}

// sqrt(float): direct call to __builtin_amdgcn_sqrtf.
CK_TILE_DEVICE float sqrt(float x)
{
    return __builtin_amdgcn_sqrtf(x);
}

// sqrt(double): direct call to __builtin_amdgcn_sqrt.
CK_TILE_DEVICE double sqrt(double x)
{
    return __builtin_amdgcn_sqrt(x);
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh(T x)
|
||||
{
|
||||
return type_convert<T>(::tanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh<float>(float x)
|
||||
{
|
||||
return ::tanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double tanh<double>(double x)
|
||||
{
|
||||
return ::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T acos(T x)
|
||||
{
|
||||
return type_convert<T>(::acosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float acos<float>(float x)
|
||||
{
|
||||
return ::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double acos<double>(double x)
|
||||
{
|
||||
return ::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T neg(T x)
|
||||
{
|
||||
return type_convert<T>(-(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
|
||||
{
|
||||
return __hneg(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T atan(T x)
|
||||
{
|
||||
return type_convert<T>(::atanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float atan<float>(float x)
|
||||
{
|
||||
return ::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double atan<double>(double x)
|
||||
{
|
||||
return ::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T sin(T x)
|
||||
{
|
||||
return type_convert<T>(::sinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float sin<float>(float x)
|
||||
{
|
||||
return ::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double sin<double>(double x)
|
||||
{
|
||||
return ::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hsin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T asin(T x)
|
||||
{
|
||||
return type_convert<T>(::asinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float asin<float>(float x)
|
||||
{
|
||||
return ::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double asin<double>(double x)
|
||||
{
|
||||
return ::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T asinh(T x)
|
||||
{
|
||||
return type_convert<T>(::asinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float asinh<float>(float x)
|
||||
{
|
||||
return ::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double asinh<double>(double x)
|
||||
{
|
||||
return ::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T acosh(T x)
|
||||
{
|
||||
return type_convert<T>(::acoshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float acosh<float>(float x)
|
||||
{
|
||||
return ::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double acosh<double>(double x)
|
||||
{
|
||||
return ::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tan(T x)
|
||||
{
|
||||
return type_convert<T>(::tanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tan<float>(float x)
|
||||
{
|
||||
return ::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double tan<double>(double x)
|
||||
{
|
||||
return ::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T atanh(T x)
|
||||
{
|
||||
return type_convert<T>(::atanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float atanh<float>(float x)
|
||||
{
|
||||
return ::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double atanh<double>(double x)
|
||||
{
|
||||
return ::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T sinh(T x)
|
||||
{
|
||||
return type_convert<T>(::sinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float sinh<float>(float x)
|
||||
{
|
||||
return ::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double sinh<double>(double x)
|
||||
{
|
||||
return ::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T ceil(T x)
|
||||
{
|
||||
return type_convert<T>(::ceilf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float ceil<float>(float x)
|
||||
{
|
||||
return ::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double ceil<double>(double x)
|
||||
{
|
||||
return ::ceil(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T cosh(T x)
|
||||
{
|
||||
return type_convert<T>(::coshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float cosh<float>(float x)
|
||||
{
|
||||
return ::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double cosh<double>(double x)
|
||||
{
|
||||
return ::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T floor(T x)
|
||||
{
|
||||
return type_convert<T>(::floorf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float floor<float>(float x)
|
||||
{
|
||||
return ::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double floor<double>(double x)
|
||||
{
|
||||
return ::floor(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hfloor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T rcp(T x)
|
||||
{
|
||||
#if !CK_TILE_WORKAROUND_SWDEV_383542
|
||||
return __frcp_rn(x);
|
||||
#else
|
||||
// return __ocml_native_recip_f32(x);
|
||||
return __builtin_amdgcn_rcpf(x);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T exp(T x)
|
||||
{
|
||||
return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
|
||||
{
|
||||
return hexp(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float exp<float>(float x)
|
||||
{
|
||||
return __ocml_exp_f32(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double exp<double>(double x)
|
||||
{
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
return type_convert<T>(__logf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
|
||||
{
|
||||
return hlog(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float log<float>(float x)
|
||||
{
|
||||
return __logf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double log<double>(double x)
|
||||
{
|
||||
return log(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T pow(T x, T gamma)
|
||||
{
|
||||
return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float pow<float>(float x, float gamma)
|
||||
{
|
||||
return powf(x, gamma);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double pow<double>(double x, double gamma)
|
||||
{
|
||||
return pow(x, gamma);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T expm1(T x)
|
||||
{
|
||||
return type_convert<T>(expm1f(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float expm1<float>(float x)
|
||||
{
|
||||
return expm1f(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double expm1<double>(double x)
|
||||
{
|
||||
return expm1(x);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i]);
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check>(
|
||||
p_data_, i, is_valid_element, buffer_size_);
|
||||
p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
|
||||
remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check>(
|
||||
p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
|
||||
oob_conditional_check>(p_data_,
|
||||
i + linear_offset,
|
||||
is_valid_element,
|
||||
buffer_size_,
|
||||
invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i]);
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -352,7 +359,8 @@ struct buffer_view<address_space_enum::global,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
|
||||
index_t i,
|
||||
index_t v_offset,
|
||||
index_t i_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -366,7 +374,38 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
|
||||
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
|
||||
dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
|
||||
index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X is vector of T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem,
|
||||
cached_buf_res_,
|
||||
i,
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
@@ -378,6 +417,7 @@ struct buffer_view<address_space_enum::global,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
|
||||
index_t i,
|
||||
index_t linear_offset,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -391,7 +431,7 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
|
||||
smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
@@ -401,25 +441,25 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add<X>(i, is_valid_element, x);
|
||||
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
this->template atomic_max<X>(i, is_valid_element, x);
|
||||
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
// tmp += x;
|
||||
// this->template set<X>(i, is_valid_element, tmp);
|
||||
}
|
||||
@@ -432,7 +472,7 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -453,7 +493,7 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -462,9 +502,9 @@ struct buffer_view<address_space_enum::global,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -477,7 +517,7 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -489,7 +529,7 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
@@ -497,7 +537,8 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void
|
||||
atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
@@ -532,13 +573,13 @@ struct buffer_view<address_space_enum::global,
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -548,7 +589,8 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void
|
||||
atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -572,11 +614,11 @@ struct buffer_view<address_space_enum::global,
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -668,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -684,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
// using buf_t = ushort __attribute__((ext_vector_type(8)));
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
|
||||
return bit_cast<X>(rtn);
|
||||
#endif
|
||||
}
|
||||
@@ -708,6 +752,23 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
|
||||
index_t v_offset,
|
||||
index_t i_offset,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -715,17 +776,17 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -735,7 +796,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -751,6 +812,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
bool constexpr workaround_int8_ds_write_issue = false;
|
||||
#endif
|
||||
|
||||
i += linear_offset; // simplicity
|
||||
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
int8_t>::value &&
|
||||
workaround_int8_ds_write_issue)
|
||||
@@ -952,8 +1014,10 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t /*linear_offset*/,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -995,17 +1059,17 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1015,7 +1079,7 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -1030,9 +1094,9 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
@@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
|
||||
NumCoord>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -46,7 +61,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -66,7 +101,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
|
||||
@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_array(
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor);
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
@@ -87,7 +88,33 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -16,6 +16,24 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* tensor_view
|
||||
* abstract the underneath memory buffer(global, LDS, etc...)
|
||||
* and provide a unified get/set function for access
|
||||
*
|
||||
* For addressing into the buffer we use 2 variable to control:
|
||||
* coord : ND tensor coordinate, will calculate the actual offset inside
|
||||
* linear_offset : 1D offset, will be used in the immediate field of
|
||||
* the buffer instruction to help reduce register usage
|
||||
*
|
||||
* User can use either of the field, or both to indexing into the tensor
|
||||
*
|
||||
* We usually provide 2 set of API for buffer get/set, e.g.
|
||||
* get_vectorized_elements()/get_vectorized_elements_raw()
|
||||
* the former usually will call intrinsic or normal C function, the later
|
||||
* usually will call inline-asm function
|
||||
*
|
||||
*/
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
@@ -49,22 +67,6 @@ struct tensor_view
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
|
||||
|
||||
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
|
||||
{
|
||||
return buf_.template get<DataType>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
|
||||
{
|
||||
buf_.template set<DataType>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
#endif
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -75,14 +77,34 @@ struct tensor_view
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element, // flag
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(coord.get_offset(),
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -94,12 +116,90 @@ struct tensor_view
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -110,11 +210,15 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
|
||||
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
|
||||
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
@@ -125,11 +229,15 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -140,15 +248,53 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -157,15 +303,36 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Note: this tile window do not support single issue
|
||||
// you need to use tile_window_linear structure for this purpose
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const AdaptorTopIndex& idx_diff_adaptor_top) const
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
@@ -279,10 +284,11 @@ struct tile_window_with_static_distribution
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -308,11 +314,11 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -338,8 +344,9 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -350,8 +357,12 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
template <typename DstTile,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -397,6 +408,7 @@ struct tile_window_with_static_distribution
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
0 /**/,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
pre_nop_);
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
|
||||
@@ -409,23 +421,24 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
|
||||
"scratch memory" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: currently async load only implemented in inline asm
|
||||
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -467,7 +480,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
@@ -482,15 +495,16 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, pre_nop_);
|
||||
smem, bottom_tensor_thread_coord, 0, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -501,8 +515,81 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -515,7 +602,6 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
@@ -530,7 +616,7 @@ struct tile_window_with_static_distribution
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -548,15 +634,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -565,8 +655,9 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
|
||||
template <index_t i_access_unsupport_ = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -591,7 +682,7 @@ struct tile_window_with_static_distribution
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -606,15 +697,16 @@ struct tile_window_with_static_distribution
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, vec_value);
|
||||
bottom_tensor_thread_coord, 0, vec_value);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -623,8 +715,9 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -650,7 +743,7 @@ struct tile_window_with_static_distribution
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -666,15 +759,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -746,7 +843,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -798,6 +896,27 @@ make_tile_window(const TensorView_& tensor_view,
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
}
|
||||
|
||||
// this version can't be called in a constexpr context
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window_raw(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
w.init_raw();
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -922,6 +1041,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
|
||||
tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution)
|
||||
{
|
||||
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution);
|
||||
w.init_raw();
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
|
||||
|
||||
1082
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
1082
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -59,8 +59,16 @@ struct magic_division32_bit_range
|
||||
// Evaluates (hi + dividend) >> shift, where hi is the upper 32 bits of the
// 64-bit product dividend * multiplier.
//
// dividend   : unsigned value to divide
// multiplier : precomputed magic multiplier
// shift      : precomputed shift amount
//
// The upper-half product is obtained in one of two ways:
//  - during constant evaluation, from a plain 64-bit multiply;
//  - otherwise, from the __umulhi intrinsic.
// The choice has to be made before anything else runs: an unconditional
// __umulhi call / return ahead of the check would leave the
// constant-evaluation branch unreachable.
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
    uint32_t hi = 0;

    if(__builtin_is_constant_evaluated())
    {
        hi = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
    }
    else
    {
        hi = __umulhi(dividend, multiplier);
    }

    return (hi + dividend) >> shift;
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
@@ -77,9 +85,18 @@ struct magic_division32_bit_range
|
||||
// Signed-argument overload: the dividend is reinterpreted bit-for-bit as uint32_t
// and then (hi + dividend_u32) >> shift is evaluated, where hi is the upper 32 bits
// of the 64-bit product dividend_u32 * multiplier. The unsigned result is converted
// to int32_t on return.
//
// dividend_i32 : value to divide (its bit pattern is used as an unsigned number)
// multiplier   : precomputed magic multiplier
// shift        : precomputed shift amount
//
// The upper-half product comes from a plain 64-bit multiply during constant
// evaluation and from the __umulhi intrinsic otherwise. The choice has to be made
// before any __umulhi call / return, otherwise the constant-evaluation branch is
// unreachable.
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
    // same reinterpretation is needed on both paths, so do it once up front
    const uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);

    uint32_t hi = 0;

    if(__builtin_is_constant_evaluated())
    {
        hi = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
    }
    else
    {
        hi = __umulhi(dividend_u32, multiplier);
    }

    return (hi + dividend_u32) >> shift;
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
|
||||
Reference in New Issue
Block a user