mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] support group from cmdline (#1295)
* support cmdline seqlen decode * silent print * update readme * update kernel launch 3d * update tile partitioner * fix spill for bf16 * modify based on comment * modify payload_t * fix bug for alibi mode * fix alibi test err * refactor kernel launch, support select timer * add missing file * remove useless code * add some comments
This commit is contained in:
@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
return __builtin_bit_cast(int32x4_t, res);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
template<index_t N, typename T> struct buffer_load_trait;
|
||||
|
||||
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
|
||||
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
|
||||
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
|
||||
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
|
||||
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
|
||||
|
||||
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
|
||||
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
|
||||
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
|
||||
#endif
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
// TODO: glc/slc/...
|
||||
template <index_t bytes>
|
||||
struct buffer_load;
|
||||
@@ -48,7 +67,7 @@ struct buffer_load<16>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -68,7 +87,7 @@ struct buffer_load<8>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = fp32x2_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -88,7 +107,7 @@ struct buffer_load<4>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -108,7 +127,7 @@ struct buffer_load<2>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -128,7 +147,7 @@ struct buffer_load<1>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -152,7 +171,7 @@ struct buffer_load_if<16>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
static_assert(sizeof(mbuf_t) == sizeof(T));
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
@@ -177,7 +196,7 @@ struct buffer_load_if<8>
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x2_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -201,7 +220,7 @@ struct buffer_load_if<4>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -225,7 +244,7 @@ struct buffer_load_if<2>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -249,7 +268,7 @@ struct buffer_load_if<1>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
|
||||
|
||||
Reference in New Issue
Block a user