mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge remote-tracking branch 'origin/feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
This commit is contained in:
@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
|
||||
@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -383,6 +383,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(num_splits != 1)
|
||||
{
|
||||
std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
|
||||
num_splits = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
|
||||
} // namespace impl
|
||||
|
||||
// TODO: glc/slc/...
|
||||
template <index_t bytes>
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct buffer_load;
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
|
||||
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
|
||||
// (exp_vector_type(xxx))
|
||||
template <>
|
||||
struct buffer_load<16>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<16, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load<8>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<8, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load<4>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<4, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load<2>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load<1>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<1, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t bytes>
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct buffer_load_if;
|
||||
|
||||
template <>
|
||||
struct buffer_load_if<16>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<16, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
static_assert(sizeof(mbuf_t) == sizeof(T));
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load_if<8>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<8, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load_if<4>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<4, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load_if<2>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct buffer_load_if<1>
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<1, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
|
||||
@@ -294,17 +379,16 @@ struct buffer_store<16>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = fp32x4_t;
|
||||
asm volatile(
|
||||
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -315,17 +399,16 @@ struct buffer_store<8>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = fp32x2_t;
|
||||
asm volatile(
|
||||
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -336,17 +419,16 @@ struct buffer_store<4>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile(
|
||||
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -357,17 +439,16 @@ struct buffer_store<2>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 2);
|
||||
using mbuf_t = short;
|
||||
asm volatile(
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -378,17 +459,16 @@ struct buffer_store<1>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile(
|
||||
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -402,21 +482,20 @@ struct buffer_store_if<16>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x4_t;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)),
|
||||
"v"(v_offset),
|
||||
"s"(res),
|
||||
"s"(s_offset),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
@@ -431,7 +510,7 @@ struct buffer_store_if<8>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
@@ -439,14 +518,13 @@ struct buffer_store_if<8>
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
|
||||
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)),
|
||||
"v"(v_offset),
|
||||
"s"(res),
|
||||
"s"(s_offset),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
@@ -461,21 +539,20 @@ struct buffer_store_if<4>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)),
|
||||
"v"(v_offset),
|
||||
"s"(res),
|
||||
"s"(s_offset),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
@@ -490,21 +567,20 @@ struct buffer_store_if<2>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 2);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = short;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)),
|
||||
"v"(v_offset),
|
||||
"s"(res),
|
||||
"s"(s_offset),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
@@ -519,21 +595,20 @@ struct buffer_store_if<1>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)),
|
||||
"v"(v_offset),
|
||||
"s"(res),
|
||||
"s"(s_offset),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0)
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t /*soffset*/,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
|
||||
: "memory");
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
: "memory");
|
||||
else
|
||||
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
|
||||
@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t flag = 0)
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
|
||||
@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
using type = thread_buffer<T, N>;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
buffer_load_if<sizeof(type)>{}(
|
||||
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
|
||||
buffer_load_if<sizeof(type), pre_nop>{}(dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_load<sizeof(type)>{}(
|
||||
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
|
||||
buffer_load<sizeof(type), pre_nop>{}(dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0)
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
async_buffer_load_dword(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset);
|
||||
async_buffer_load_dword_v(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_element_space_size,
|
||||
index_t is_valid_element = 0)
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
|
||||
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// unfortunately async copy can not make sure invalid data is zero inside LDS
|
||||
@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
// buffer_load OOB still working.
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_element_space_size)
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
|
||||
@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
" ::);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop()
|
||||
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
|
||||
{
|
||||
#if 1
|
||||
asm volatile("\
|
||||
s_nop 0 \n \
|
||||
" ::);
|
||||
asm volatile("s_nop %0" : : "n"(cnt) :);
|
||||
#else
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_sched_barrier(cnt);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
#include "hip/hip_version.h"
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
@@ -147,6 +148,14 @@
|
||||
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() {}
|
||||
|
||||
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
|
||||
{
|
||||
return address_space_enum::generic;
|
||||
@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
T* p_data_ = nullptr;
|
||||
BufferSizeType buffer_size_;
|
||||
int32x4_t cached_buf_res_;
|
||||
remove_cvref_t<T> invalid_element_value_ = T{0};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view()
|
||||
: p_data_{}, buffer_size_{}, invalid_element_value_{}
|
||||
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
|
||||
: p_data_{p_data},
|
||||
buffer_size_{buffer_size},
|
||||
cached_buf_res_{0},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
// this is non constexpr intentially (will call some intrinsic internally)
|
||||
// Must call for buffers that need *_raw load/store
|
||||
CK_TILE_HOST_DEVICE void init_raw()
|
||||
{
|
||||
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
|
||||
{
|
||||
return address_space_enum::global;
|
||||
@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
|
||||
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
|
||||
index_t i,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
|
||||
dst, p_data_, i, buffer_size_, is_valid_element);
|
||||
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
|
||||
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
|
||||
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
|
||||
index_t i,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
// X is vector of T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem, p_data_, i, buffer_size_);
|
||||
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() {}
|
||||
|
||||
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
|
||||
{
|
||||
return address_space_enum::lds;
|
||||
@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() {}
|
||||
|
||||
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
|
||||
{
|
||||
return address_space_enum::vgpr;
|
||||
|
||||
@@ -36,30 +36,37 @@ template <typename T,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
bool oob_conditional_check = true>
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
|
||||
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord>
|
||||
index_t NumCoord,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto
|
||||
async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
NumCoord>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load(lds_tile);
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
|
||||
@@ -35,6 +35,8 @@ struct null_tile_window
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
|
||||
|
||||
CK_TILE_DEVICE void init_raw() {}
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
};
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ struct tensor_view
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
|
||||
@@ -85,30 +87,34 @@ struct tensor_view
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check>(
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord) const
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
|
||||
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
|
||||
@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
|
||||
|
||||
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
|
||||
// sub-dword tensor...
|
||||
template <typename DstrTensors, index_t v>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
|
||||
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
|
||||
CK_TILE_DEVICE void
|
||||
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
|
||||
{
|
||||
constexpr index_t tensor_bytes =
|
||||
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
|
||||
if constexpr(v == 0 && tensor_bytes % 4 == 0)
|
||||
using elem_type = typename DstrTensors::DataType;
|
||||
constexpr index_t elem_size = sizeof(elem_type);
|
||||
|
||||
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
|
||||
|
||||
// # bytes per write = 4
|
||||
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
|
||||
{
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
auto& buffer = dstr_tensor.get_thread_buffer();
|
||||
|
||||
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
|
||||
if constexpr(elem_size == 1)
|
||||
{
|
||||
// # elements per write = 4
|
||||
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
|
||||
|
||||
buffer[i_write * 4 + 0] = values.x;
|
||||
buffer[i_write * 4 + 1] = values.y;
|
||||
buffer[i_write * 4 + 2] = values.z;
|
||||
buffer[i_write * 4 + 3] = values.w;
|
||||
}
|
||||
else if constexpr(elem_size == 2)
|
||||
{
|
||||
// # elements per write = 2
|
||||
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
|
||||
|
||||
buffer[i_write * 2 + 0] = values.x;
|
||||
buffer[i_write * 2 + 1] = values.y;
|
||||
}
|
||||
else if constexpr(elem_size == 4)
|
||||
{
|
||||
// # elements per write = 1
|
||||
constexpr elem_type value = 0;
|
||||
|
||||
buffer[i_write] = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "type not supported");
|
||||
}
|
||||
});
|
||||
#else
|
||||
using dvec_t = array<index_t, tensor_bytes / 4>;
|
||||
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
|
||||
for(auto i = 0; i < tensor.size(); i++)
|
||||
tensor.get(i) = v;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
|
||||
dstr_tensor);
|
||||
tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
|
||||
dstr_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile, bool oob_conditional_check = true>
|
||||
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
bool_constant<oob_conditional_check>{},
|
||||
pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
});
|
||||
});
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
|
||||
"scratch memory" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: currently async load only implemented in inline asm
|
||||
template <typename LdsTileWindow_, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
|
||||
@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord);
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
@@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
// minimize occupancy
|
||||
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
if constexpr(kK0BlockLength <= 32)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
|
||||
@@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
q_dram_window.init_raw();
|
||||
|
||||
// TODO: we use async Copy for K, which is inline asm
|
||||
// a side effect is we have to use inline asm for q as well
|
||||
@@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dram_window.init_raw();
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = [&]() {
|
||||
if constexpr(kPadSeqLenK &&
|
||||
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window);
|
||||
k_dram_window,
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
@@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window =
|
||||
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail
|
||||
|
||||
@@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
|
||||
first = false;
|
||||
else
|
||||
os << delim;
|
||||
os << static_cast<T>(v);
|
||||
|
||||
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
|
||||
{
|
||||
os << ck::type_convert<float>(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
os << static_cast<T>(v);
|
||||
}
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename OutElementOp,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType>
|
||||
bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
auto pass = true; // return status
|
||||
|
||||
using CShuffleDataType = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
Tensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<CShuffleDataType> c(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-1, 1});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
|
||||
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
|
||||
// random scale values
|
||||
auto scale_in = type_convert<float>(
|
||||
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
|
||||
auto scale_wei = type_convert<float>(
|
||||
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
|
||||
auto scale_out = type_convert<float>(
|
||||
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
|
||||
|
||||
// initialize out_element_op for each iteration
|
||||
const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out};
|
||||
|
||||
std::cout << "scale_in: " << scale_in << std::endl;
|
||||
std::cout << "scale_wei: " << scale_wei << std::endl;
|
||||
std::cout << "scale_out: " << scale_out << std::endl;
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
std::cout << "\nVerifying algorithm against reference convolution..." << std::endl;
|
||||
std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7)
|
||||
<< get_rtol<OutDataType>() << ", " << get_atol<OutDataType>() << ")" << std::endl;
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
CShuffleDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
PassThrough>{};
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weight,
|
||||
c,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
PassThrough{});
|
||||
|
||||
c.SetZero();
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); });
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
out_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(device_output,
|
||||
host_output,
|
||||
"Error: Device and Host results do not match!",
|
||||
get_rtol<OutDataType>(),
|
||||
get_atol<OutDataType>());
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<InDataType>(std::cout << "input : ", input.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<WeiDataType>(std::cout << "weight: ", weight.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<OutDataType>(
|
||||
std::cout << "host_output : ", host_output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<OutDataType>(
|
||||
std::cout << "device_output: ", device_output.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
{},
|
||||
{},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
run_impl(op_ptr, argument_ptr);
|
||||
}
|
||||
|
||||
std::cout << "Best configuration parameters:"
|
||||
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -57,6 +57,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -134,6 +135,8 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
|
||||
220
profiler/src/profile_grouped_conv_fwd_outelementop.cpp
Normal file
220
profiler/src/profile_grouped_conv_fwd_outelementop.cpp
Normal file
@@ -0,0 +1,220 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
enum struct ConvLayout
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK = 0,
|
||||
NHWGC_GKYXC_NHWGK = 1
|
||||
};
|
||||
|
||||
enum struct OutElementOp
|
||||
{
|
||||
ConvScale = 0,
|
||||
ConvInvScale = 1
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F8_F8_F8 = 0,
|
||||
BF8_BF8_F8 = 1,
|
||||
F8_BF8_F8 = 2,
|
||||
BF8_F8_F8 = 3
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_fwd_outelementop"
|
||||
#define OP_DESC "Grouped Convolution Forward+Elementwise Operation"
|
||||
|
||||
static void print_helper_msg()
|
||||
{
|
||||
// clang-format off
|
||||
std::cout
|
||||
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: Input fp8, Weight fp8, Output fp8\n"
|
||||
<< " 1: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 2: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 3: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< "arg3: element-wise operation (0: ConvScale\n"
|
||||
<< " 1: ConvInvScale)\n"
|
||||
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg5: verification (0: no, 1: yes)\n"
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg8: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
int grouped_conv_fwd_outelementop(int argc, char* argv[])
|
||||
{
|
||||
|
||||
// 9 total, 1 for num_dim_spatial
|
||||
if(argc < 10)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto op = static_cast<OutElementOp>(std::stoi(argv[3]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[4]));
|
||||
const bool do_verification = std::stoi(argv[5]);
|
||||
const int init_method = std::stoi(argv[6]);
|
||||
const bool do_log = std::stoi(argv[7]);
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0]
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
|
||||
using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale;
|
||||
|
||||
constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
auto profile = [&](auto num_dim_spatial_tmp,
|
||||
auto in_layout,
|
||||
auto wei_layout,
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto out_element_op,
|
||||
auto a_compute_type,
|
||||
auto b_compute_type) {
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
using InDataType = decltype(in_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using OutElementOp = decltype(out_element_op);
|
||||
|
||||
using AComputeType = decltype(a_compute_type);
|
||||
using BComputeType = decltype(b_compute_type);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_fwd_outelementop_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
OutElementOp,
|
||||
AComputeType,
|
||||
BComputeType>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(op == OutElementOp::ConvScale)
|
||||
{
|
||||
if(data_type == ConvDataType::F8_F8_F8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvScale{}, F8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_BF8_F8)
|
||||
{
|
||||
return profile(I3,
|
||||
NDHWGC{},
|
||||
GKZYXC{},
|
||||
NDHWGK{},
|
||||
BF8{},
|
||||
BF8{},
|
||||
F8{},
|
||||
ConvScale{},
|
||||
BF8{},
|
||||
BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_BF8_F8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, ConvScale{}, F8{}, BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_F8_F8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, ConvScale{}, BF8{}, F8{});
|
||||
}
|
||||
}
|
||||
else if(op == OutElementOp::ConvInvScale)
|
||||
{
|
||||
if(data_type == ConvDataType::F8_F8_F8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_BF8_F8)
|
||||
{
|
||||
return profile(I3,
|
||||
NDHWGC{},
|
||||
GKZYXC{},
|
||||
NDHWGK{},
|
||||
BF8{},
|
||||
BF8{},
|
||||
F8{},
|
||||
ConvInvScale{},
|
||||
BF8{},
|
||||
BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_BF8_F8)
|
||||
{
|
||||
return profile(I3,
|
||||
NDHWGC{},
|
||||
GKZYXC{},
|
||||
NDHWGK{},
|
||||
F8{},
|
||||
BF8{},
|
||||
F8{},
|
||||
ConvInvScale{},
|
||||
F8{},
|
||||
BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_F8_F8)
|
||||
{
|
||||
return profile(I3,
|
||||
NDHWGC{},
|
||||
GKZYXC{},
|
||||
NDHWGK{},
|
||||
BF8{},
|
||||
F8{},
|
||||
F8{},
|
||||
ConvInvScale{},
|
||||
BF8{},
|
||||
F8{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_outelementop);
|
||||
20
script/profile_grouped_conv_fwd_outelementop.sh
Executable file
20
script/profile_grouped_conv_fwd_outelementop.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
DRIVER="../build/bin/ckProfiler"
|
||||
|
||||
OP=$1
|
||||
DATATYPE=$2
|
||||
OUTELEMENTOP=$3
|
||||
LAYOUT=$4
|
||||
VERIFY=$5
|
||||
INIT=$6
|
||||
LOG=$7
|
||||
TIME=$8
|
||||
|
||||
N=$9
|
||||
|
||||
####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx
|
||||
$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
Reference in New Issue
Block a user