Merge branch 'develop' into feature/cond-add-splitkv

This commit is contained in:
Po Yen Chen
2024-07-09 03:42:25 +08:00
committed by GitHub
18 changed files with 1185 additions and 244 deletions

View File

@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T

View File

@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass

View File

@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass

View File

@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
} // namespace impl
// TODO: glc/slc/...
template <index_t bytes>
template <index_t bytes, bool pre_nop = false>
struct buffer_load;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
// (exp_vector_type(xxx))
template <>
struct buffer_load<16>
template <bool pre_nop>
struct buffer_load<16, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<8>
template <bool pre_nop>
struct buffer_load<8, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<4>
template <bool pre_nop>
struct buffer_load<4, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<2>
template <bool pre_nop>
struct buffer_load<2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<1>
template <bool pre_nop>
struct buffer_load<1, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;
template <>
struct buffer_load_if<16>
template <bool pre_nop>
struct buffer_load_if<16, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<8>
template <bool pre_nop>
struct buffer_load_if<8, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<4>
template <bool pre_nop>
struct buffer_load_if<4, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<2>
template <bool pre_nop>
struct buffer_load_if<2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<1>
template <bool pre_nop>
struct buffer_load_if<1, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
@@ -294,17 +379,16 @@ struct buffer_store<16>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile(
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
@@ -315,17 +399,16 @@ struct buffer_store<8>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile(
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
@@ -336,17 +419,16 @@ struct buffer_store<4>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
@@ -357,17 +439,16 @@ struct buffer_store<2>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
@@ -378,17 +459,16 @@ struct buffer_store<1>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
@@ -402,21 +482,20 @@ struct buffer_store_if<16>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
@@ -431,7 +510,7 @@ struct buffer_store_if<8>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
@@ -439,14 +518,13 @@ struct buffer_store_if<8>
auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
@@ -461,21 +539,20 @@ struct buffer_store_if<4>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
@@ -490,21 +567,20 @@ struct buffer_store_if<2>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
@@ -519,21 +595,20 @@ struct buffer_store_if<1>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0)
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_load_if<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
buffer_load_if<sizeof(type), pre_nop>{}(dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
}
else
{
buffer_load<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
buffer_load<sizeof(type), pre_nop>{}(dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0)
index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset);
async_buffer_load_dword_v(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
}
template <index_t N,
@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0)
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// unfortunately async copy can not make sure invalid data is zero inside LDS
@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size)
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// buffer_store requires:

View File

@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::);
}
CK_TILE_DEVICE void s_nop()
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
asm volatile("s_nop %0" : : "n"(cnt) :);
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(cnt);
#endif
}

View File

@@ -21,6 +21,7 @@
#define __gfx12__
#endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
@@ -147,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif

View File

@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::generic;
@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
: p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{
}
// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::global;
@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
dst, p_data_, i, buffer_size_, is_valid_element);
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_);
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::lds;
@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::vgpr;

View File

@@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord>
index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load(lds_tile);
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)

View File

@@ -35,6 +35,8 @@ struct null_tile_window
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
CK_TILE_DEVICE void init_raw() {}
WindowLengths window_lengths_;
};

View File

@@ -36,6 +36,8 @@ struct tensor_view
{
}
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
@@ -85,30 +87,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
}
// X is vector of DataType.

View File

@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
constexpr index_t tensor_bytes =
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
using elem_type = typename DstrTensors::DataType;
constexpr index_t elem_size = sizeof(elem_type);
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
// # bytes per write = 4
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
{
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
auto& buffer = dstr_tensor.get_thread_buffer();
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
if constexpr(elem_size == 1)
{
// # elements per write = 4
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
buffer[i_write * 4 + 0] = values.x;
buffer[i_write * 4 + 1] = values.y;
buffer[i_write * 4 + 2] = values.z;
buffer[i_write * 4 + 3] = values.w;
}
else if constexpr(elem_size == 2)
{
// # elements per write = 2
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
buffer[i_write * 2 + 0] = values.x;
buffer[i_write * 2 + 1] = values.y;
}
else if constexpr(elem_size == 4)
{
// # elements per write = 1
constexpr elem_type value = 0;
buffer[i_write] = value;
}
else
{
static_assert(false, "type not supported");
}
});
#else
using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v;
#endif
}
else
{
tile_elementwise_inout(
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
dstr_tensor);
tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
dstr_tensor);
}
}

View File

@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true>
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{});
bool_constant<oob_conditional_check>{},
pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
}
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord);
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;

View File

@@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
{
return 1;
}
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
@@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
@@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
@@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window);
k_dram_window,
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
@@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail

View File

@@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << static_cast<T>(v);
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
{
os << ck::type_convert<float>(v);
}
else
{
os << static_cast<T>(v);
}
}
return os;
}

View File

@@ -0,0 +1,352 @@
#pragma once
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace profiler {
template <typename DataType>
inline constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename OutElementOp,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
{
auto pass = true; // return status
using CShuffleDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
Tensor<InDataType> input(in_g_n_c_wis_desc);
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<CShuffleDataType> c(out_g_n_k_wos_desc);
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << host_output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-1, 1});
break;
default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
// random scale values
auto scale_in = type_convert<float>(
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
auto scale_wei = type_convert<float>(
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
auto scale_out = type_convert<float>(
type_convert<f8_t>(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX)));
// initialize out_element_op for each iteration
const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out};
std::cout << "scale_in: " << scale_in << std::endl;
std::cout << "scale_wei: " << scale_wei << std::endl;
std::cout << "scale_out: " << scale_out << std::endl;
// run reference op
if(do_verification)
{
std::cout << "\nVerifying algorithm against reference convolution..." << std::endl;
std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7)
<< get_rtol<OutDataType>() << ", " << get_atol<OutDataType>() << ")" << std::endl;
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
PassThrough>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input,
weight,
c,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
PassThrough{});
c.SetZero();
ref_invoker.Run(ref_argument);
host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); });
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output,
host_output,
"Error: Device and Host results do not match!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
if(do_log)
{
LogRangeAsType<InDataType>(std::cout << "input : ", input.mData, ",")
<< std::endl;
LogRangeAsType<WeiDataType>(std::cout << "weight: ", weight.mData, ",")
<< std::endl;
LogRangeAsType<OutDataType>(
std::cout << "host_output : ", host_output.mData, ",")
<< std::endl;
LogRangeAsType<OutDataType>(
std::cout << "device_output: ", device_output.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
};
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{},
{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
run_impl(op_ptr, argument_ptr);
}
std::cout << "Best configuration parameters:"
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -57,6 +57,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
endif()
@@ -134,6 +135,8 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")

View File

@@ -0,0 +1,220 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "profiler_operation_registry.hpp"
#include <iostream>
enum struct ConvLayout
{
GNHWC_GKYXC_GNHWK = 0,
NHWGC_GKYXC_NHWGK = 1
};
enum struct OutElementOp
{
ConvScale = 0,
ConvInvScale = 1
};
enum struct ConvDataType
{
F8_F8_F8 = 0,
BF8_BF8_F8 = 1,
F8_BF8_F8 = 2,
BF8_F8_F8 = 3
};
#define OP_NAME "grouped_conv_fwd_outelementop"
#define OP_DESC "Grouped Convolution Forward+Elementwise Operation"
static void print_helper_msg()
{
// clang-format off
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: Input fp8, Weight fp8, Output fp8\n"
<< " 1: Input bf8, Weight bf8, Output fp8\n"
<< " 2: Input fp8, Weight bf8, Output fp8\n"
<< " 3: Input bf8, Weight fp8, Output fp8)\n"
<< "arg3: element-wise operation (0: ConvScale\n"
<< " 1: ConvInvScale)\n"
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
// clang-format on
}
int grouped_conv_fwd_outelementop(int argc, char* argv[])
{
// 9 total, 1 for num_dim_spatial
if(argc < 10)
{
print_helper_msg();
return 1;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto op = static_cast<OutElementOp>(std::stoi(argv[3]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const int init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const int num_dim_spatial = std::stoi(argv[9]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0]
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
{
print_helper_msg();
return 1;
}
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale;
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp,
auto in_layout,
auto wei_layout,
auto out_layout,
auto in_type,
auto wei_type,
auto out_type,
auto out_element_op,
auto a_compute_type,
auto b_compute_type) {
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using OutElementOp = decltype(out_element_op);
using AComputeType = decltype(a_compute_type);
using BComputeType = decltype(b_compute_type);
bool pass = ck::profiler::profile_grouped_conv_fwd_outelementop_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
OutElementOp,
AComputeType,
BComputeType>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
};
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(op == OutElementOp::ConvScale)
{
if(data_type == ConvDataType::F8_F8_F8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvScale{}, F8{}, F8{});
}
else if(data_type == ConvDataType::BF8_BF8_F8)
{
return profile(I3,
NDHWGC{},
GKZYXC{},
NDHWGK{},
BF8{},
BF8{},
F8{},
ConvScale{},
BF8{},
BF8{});
}
else if(data_type == ConvDataType::F8_BF8_F8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, ConvScale{}, F8{}, BF8{});
}
else if(data_type == ConvDataType::BF8_F8_F8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, ConvScale{}, BF8{}, F8{});
}
}
else if(op == OutElementOp::ConvInvScale)
{
if(data_type == ConvDataType::F8_F8_F8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{});
}
else if(data_type == ConvDataType::BF8_BF8_F8)
{
return profile(I3,
NDHWGC{},
GKZYXC{},
NDHWGK{},
BF8{},
BF8{},
F8{},
ConvInvScale{},
BF8{},
BF8{});
}
else if(data_type == ConvDataType::F8_BF8_F8)
{
return profile(I3,
NDHWGC{},
GKZYXC{},
NDHWGK{},
F8{},
BF8{},
F8{},
ConvInvScale{},
F8{},
BF8{});
}
else if(data_type == ConvDataType::BF8_F8_F8)
{
return profile(I3,
NDHWGC{},
GKZYXC{},
NDHWGK{},
BF8{},
F8{},
F8{},
ConvInvScale{},
BF8{},
F8{});
}
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_outelementop);

View File

@@ -0,0 +1,20 @@
#!/bin/bash
## GPU visibility
export HIP_VISIBLE_DEVICES=0
DRIVER="../build/bin/ckProfiler"
OP=$1
DATATYPE=$2
OUTELEMENTOP=$3
LAYOUT=$4
VERIFY=$5
INIT=$6
LOG=$7
TIME=$8
N=$9
####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx
$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1
$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1