From fca10f32d9e483021ebb91edf6b8109dfa396360 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 9 Jul 2024 02:09:55 +0800 Subject: [PATCH] [CK_TILE] wa prec, remove sgpr offset for inline asm (#1356) * wa prec, remove sgpr offset for inline asm * macro for set tile * ignore unused param if no kernel instances in host API * fix more prec issue * cache buffer resource * fix * support pre-nop * clear tile by vector type members * add workaround to reduce scratch memory * conditionally enable workaround code * enable workaround start from certain build version * fallback set_tile() implementation from certain build version * undo template argument changes * put dummy asm in load_raw() * fix comments, refactor s_nop inside buffer_load --------- Co-authored-by: PoYen, Chen [ROCm/composable_kernel commit: 8182976c37433808b5e3a27a6536d1b74b0c23a1] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 + .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 3 + .../core/arch/amd_buffer_addressing.hpp | 518 +++++++++++------- include/ck_tile/core/arch/arch.hpp | 8 +- include/ck_tile/core/config.hpp | 9 + include/ck_tile/core/tensor/buffer_view.hpp | 45 +- include/ck_tile/core/tensor/load_tile.hpp | 19 +- .../ck_tile/core/tensor/null_tile_window.hpp | 2 + include/ck_tile/core/tensor/tensor_view.hpp | 24 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 56 +- include/ck_tile/core/tensor/tile_window.hpp | 100 +++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 33 +- 13 files changed, 581 insertions(+), 243 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0160915a54..0df115dc3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -271,7 +271,9 @@ class FmhaBwdApiPool: per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' 
if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) # GEMM0: Q@K=S^T diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1486671f6b..137d3a2f70 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -278,6 +278,9 @@ class FmhaFwdApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 419fbaaea8..5093945095 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) @dataclass diff --git 
a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 2cd8bb5f01..7f488d1b71 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer> { using payloa } // namespace impl // TODO: glc/slc/... -template +template struct buffer_load; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) -template <> -struct buffer_load<16> +template +struct buffer_load<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<8> +template +struct buffer_load<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); using mbuf_t = typename 
impl::buffer_load_trait<8, T>::payload_t; - asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<4> +template +struct buffer_load<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<2> +template +struct buffer_load<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually using mbuf_t = typename 
impl::buffer_load_trait<2, T>::payload_t; - asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<1> +template +struct buffer_load<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template +template struct buffer_load_if; -template <> -struct buffer_load_if<16> +template +struct buffer_load_if<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = 
__builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<8> +template +struct buffer_load_if<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 
1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<4> +template +struct buffer_load_if<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<2> +template +struct buffer_load_if<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" 
- "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<1> +template +struct buffer_load_if<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : 
"memory"); } }; #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" @@ -294,17 +379,16 @@ struct buffer_store<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; - asm volatile( - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -315,17 +399,16 @@ struct buffer_store<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; - asm volatile( - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -336,17 +419,16 @@ struct buffer_store<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -357,17 +439,16 @@ struct 
buffer_store<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); using mbuf_t = short; - asm volatile( - "buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -378,17 +459,16 @@ struct buffer_store<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -402,21 +482,20 @@ struct buffer_store_if<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -431,7 +510,7 @@ struct buffer_store_if<8> CK_TILE_DEVICE void 
operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { @@ -439,14 +518,13 @@ struct buffer_store_if<8> auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -461,21 +539,20 @@ struct buffer_store_if<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -490,21 +567,20 @@ struct buffer_store_if<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_short %0, %1, %2, 0 offen 
offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -519,21 +595,20 @@ struct buffer_store_if<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -CK_TILE_DEVICE void async_buffer_load_dword(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); + else + asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); } CK_TILE_DEVICE void 
async_buffer_load_fence(index_t cnt = 0) @@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, @@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, using type = thread_buffer; if constexpr(oob_conditional_check) { - buffer_load_if{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load_if{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } else { - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } } template + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0) + index_t src_immediate_addr_offset = 0, + bool_constant = {}) { static_assert(sizeof(T) * N == 4, "wrong! 
not implemented vector size"); - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, - index_t is_valid_element = 0) + index_t is_valid_element = 0, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + index_t is_valid_element = 0, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, // buffer_load OOB still working. 
template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_element_space_size) + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); @@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); amd_async_buffer_load_impl( - smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); } // buffer_store requires: diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 4a69f67ae3..65a3a4e2ff 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() " ::); } -CK_TILE_DEVICE void s_nop() +CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 - asm volatile("\ - s_nop 0 \n \ - " ::); + asm volatile("s_nop %0" : : "n"(cnt) :); #else - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(cnt); #endif } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 
83637e18e4..fa28aa2be9 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -21,6 +21,7 @@ #define __gfx12__ #endif +#include "hip/hip_version.h" #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -147,6 +148,14 @@ #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #endif +#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE +#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091 +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ffe8f7a4fd..ed705c91e7 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -69,6 +69,8 @@ struct buffer_view invalid_element_value_ = T{0}; CK_TILE_HOST_DEVICE constexpr buffer_view() - : p_data_{}, buffer_size_{}, invalid_element_value_{} + : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + : p_data_{p_data}, + buffer_size_{buffer_size}, + cached_buf_res_{0}, + invalid_element_value_{invalid_element_value} { } + // this is non constexpr intentionally (will call some intrinsic internally) + // Must call for buffers that need *_raw load/store + CK_TILE_HOST_DEVICE void init_raw() + { + cached_buf_res_ =
make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type)); + } + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() { return address_space_enum::global; @@ -333,12 +346,15 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t i, + bool is_valid_element, + bool_constant = {}) const { constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -349,18 +365,21 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - dst, p_data_, i, buffer_size_, is_valid_element); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check, pre_nop>( + dst, cached_buf_res_, i, is_valid_element, bool_constant{}); } // i is offset of T, not X. i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, + index_t i, + bool /*is_valid_element*/, + bool_constant = {}) const { // X is vector of T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -371,8 +390,8 @@ struct buffer_view, t_per_x, Coherence>( - smem, p_data_, i, buffer_size_); + amd_async_buffer_load_with_oob_raw, t_per_x, Coherence>( + smem, cached_buf_res_, i, bool_constant{}); } // i is offset of T, not X. 
i should be aligned to X @@ -627,6 +646,8 @@ struct buffer_view + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto load_tile_raw(T& tile, const tile_window_with_static_distribution& tile_window, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}); + tile_window.load_raw(tile, bool_constant{}, bool_constant{}); } template + index_t NumCoord, + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const tile_window_with_static_distribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}, + bool_constant = {}) { - return tile_window.async_load(lds_tile); + return tile_window.async_load_raw( + lds_tile, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 89806203ab..9707f2990a 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -35,6 +35,8 @@ struct null_tile_window CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + CK_TILE_DEVICE void init_raw() {} + WindowLengths window_lengths_; }; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 656309532e..4655eec241 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -36,6 +36,8 @@ struct tensor_view { } + CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); } + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() @@ -85,30 +87,34 @@ struct tensor_view // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE void - get_vectorized_elements_raw(remove_cvref_t& dst, - const TensorCoord& coord, - bool_constant = {}) const + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}, + bool_constant = {}) const { - return buf_.template get_raw( + return buf_.template get_raw( dst, coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); } template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, - const TensorCoord& coord) const + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( + remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const { - return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + return buf_.template async_get_raw( + smem, coord.get_offset(), true /*not used*/, bool_constant{}); } // X is vector of DataType. diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 5fecd19dcd..79018b9ced 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // sub-dword tensor... 
-template -CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +template +CK_TILE_DEVICE void +set_tile(DstrTensors& dstr_tensor, number, bool_constant = {}) { - constexpr index_t tensor_bytes = - DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); - if constexpr(v == 0 && tensor_bytes % 4 == 0) + using elem_type = typename DstrTensors::DataType; + constexpr index_t elem_size = sizeof(elem_type); + + constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size; + + // # bytes per write = 4 + if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt) { +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + auto& buffer = dstr_tensor.get_thread_buffer(); + + static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) { + if constexpr(elem_size == 1) + { + // # elements per write = 4 + constexpr auto values = ext_vector_t{0, 0, 0, 0}; + + buffer[i_write * 4 + 0] = values.x; + buffer[i_write * 4 + 1] = values.y; + buffer[i_write * 4 + 2] = values.z; + buffer[i_write * 4 + 3] = values.w; + } + else if constexpr(elem_size == 2) + { + // # elements per write = 2 + constexpr auto values = ext_vector_t{0, 0}; + + buffer[i_write * 2 + 0] = values.x; + buffer[i_write * 2 + 1] = values.y; + } + else if constexpr(elem_size == 4) + { + // # elements per write = 1 + constexpr elem_type value = 0; + + buffer[i_write] = value; + } + else + { + static_assert(false, "type not supported"); + } + }); +#else using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; +#endif } else { - tile_elementwise_inout( - [](auto& x) { x = type_convert(v); }, - dstr_tensor); + tile_elementwise_inout([](auto& x) { x = type_convert(v); }, + dstr_tensor); } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2c38c6aa2c..70f381db1f 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ 
b/include/ck_tile/core/tensor/tile_window.hpp @@ -344,9 +344,10 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, - bool_constant = {}) const + bool_constant = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -373,7 +374,13 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); @@ -384,7 +391,8 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, - bool_constant{}); + bool_constant{}, + pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -399,12 +407,17 @@ struct tile_window_with_static_distribution } }); }); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile("; this inline asm is workaround to prevent compiler from using too much " + "scratch memory" ::); +#endif } // TODO: currently async load only implemented in inline asm - template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - bool_constant = {}) const + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + bool_constant = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -449,11 +462,17 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - 
constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord); + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -668,6 +687,67 @@ struct tile_window_with_static_distribution }); } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // 
pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e9a14ca5ac..8251627e6c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync return Problem::kBlockPerCu; else { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + if constexpr(kK0BlockLength <= 32) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && @@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm // a side effect is we have 
to use inline asm for q as well @@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); + k_dram_window, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync { // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + 
async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail