mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
topk_softmax (#1592)
* topk_softmax * remove some files * fix atomic linear_offset * address various comments, and change sfc get_index api to static(tuple)
This commit is contained in:
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
template<index_t N, typename T> struct smem_load_trait;
|
||||
|
||||
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
|
||||
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
|
||||
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
|
||||
template <index_t>
|
||||
struct smem_load;
|
||||
|
||||
template <>
|
||||
struct smem_load<16>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
|
||||
asm volatile("ds_read_b128 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<8>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
|
||||
asm volatile("ds_read_b64 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<4>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
|
||||
asm volatile("ds_read_b32 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<2>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u16 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<1>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u8 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl{
|
||||
|
||||
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
// Direct loads from global to LDS.
// This declaration is bound through __asm to the symbol
// "llvm.amdgcn.raw.buffer.load.lds"; no body is provided here.
// NOTE(review): the per-parameter meaning below is inferred from the names
// and should be confirmed against the LLVM AMDGPU intrinsic documentation.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, // buffer resource descriptor
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr, // destination, pointer in address space 3
|
||||
index_t size, // presumably the transfer size in bytes
|
||||
index_t voffset, // presumably the per-thread offset
|
||||
index_t soffset, // presumably the wave-uniform offset
|
||||
index_t offset, // presumably the immediate offset
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); // presumably cache/coherence bits
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_linear_addr_offset,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset,
|
||||
index_t dst_linear_addr_offset,
|
||||
index_t is_valid_element = 1)
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
is_valid_element);
|
||||
}
|
||||
else
|
||||
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
dst_linear_addr_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2014,6 +2156,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2041,16 +2186,19 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2066,6 +2214,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
@@ -2086,12 +2240,42 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
@@ -2146,6 +2330,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const index_t dst_linear_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
dst_thread_element_valid);
|
||||
}
|
||||
|
||||
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
// This declaration is bound through __asm to the symbol
// "llvm.amdgcn.raw.buffer.load.lds"; no body is provided here.
// NOTE(review): the per-parameter meaning below is inferred from the names
// and should be confirmed against the LLVM AMDGPU intrinsic documentation.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, // buffer resource descriptor
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr, // destination, pointer in address space 3
|
||||
index_t size, // presumably the transfer size in bytes
|
||||
index_t voffset, // presumably the per-thread offset
|
||||
index_t soffset, // presumably the wave-uniform offset
|
||||
index_t offset, // presumably the immediate offset
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); // presumably cache/coherence bits
|
||||
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
|
||||
Reference in New Issue
Block a user