[CK_TILE] support group from cmdline (#1295)

* support cmdline seqlen decode

* silent print

* update readme

* update kernel launch 3d

* update tile partitioner

* fix spill for bf16

* modify based on comment

* modify payload_t

* fix bug for alibi mode

* fix alibi test err

* refactor kernel launch, support select timer

* add missing file

* remove useless code

* add some comments
This commit is contained in:
carlushuang
2024-05-28 11:13:21 +08:00
committed by GitHub
parent 02fa2c298b
commit 5055b3bdcb
16 changed files with 534 additions and 261 deletions

View File

@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
return __builtin_bit_cast(int32x4_t, res);
}
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
template<index_t N, typename T> struct buffer_load_trait;
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
#endif
// clang-format on
} // namespace impl
// TODO: glc/slc/...
template <index_t bytes>
struct buffer_load;
@@ -48,7 +67,7 @@ struct buffer_load<16>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -68,7 +87,7 @@ struct buffer_load<8>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -88,7 +107,7 @@ struct buffer_load<4>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -108,7 +127,7 @@ struct buffer_load<2>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -128,7 +147,7 @@ struct buffer_load<1>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -152,7 +171,7 @@ struct buffer_load_if<16>
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
@@ -177,7 +196,7 @@ struct buffer_load_if<8>
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x2_t;
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
@@ -201,7 +220,7 @@ struct buffer_load_if<4>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
@@ -225,7 +244,7 @@ struct buffer_load_if<2>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
@@ -249,7 +268,7 @@ struct buffer_load_if<1>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"