diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0160915a54..0df115dc3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -271,7 +271,9 @@ class FmhaBwdApiPool: per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) # GEMM0: Q@K=S^T diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1486671f6b..137d3a2f70 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -278,6 +278,9 @@ class FmhaFwdApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 419fbaaea8..5093945095 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) @dataclass diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 2cd8bb5f01..7f488d1b71 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer> { using payloa } // namespace impl // TODO: glc/slc/... -template +template struct buffer_load; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) -template <> -struct buffer_load<16> +template +struct buffer_load<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<8> +template +struct buffer_load<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<4> +template +struct buffer_load<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<2> +template +struct buffer_load<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<1> +template +struct buffer_load<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template +template struct buffer_load_if; -template <> -struct buffer_load_if<16> +template +struct buffer_load_if<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<8> +template +struct buffer_load_if<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<4> +template +struct buffer_load_if<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<2> +template +struct buffer_load_if<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<1> +template +struct buffer_load_if<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" @@ -294,17 +379,16 @@ struct buffer_store<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; - asm volatile( - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -315,17 +399,16 @@ struct buffer_store<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; - asm volatile( - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -336,17 +419,16 @@ struct buffer_store<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -357,17 +439,16 @@ struct buffer_store<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); using mbuf_t = short; - asm volatile( - "buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -378,17 +459,16 @@ struct buffer_store<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -402,21 +482,20 @@ struct buffer_store_if<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -431,7 +510,7 @@ struct buffer_store_if<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { @@ -439,14 +518,13 @@ struct buffer_store_if<8> auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -461,21 +539,20 @@ struct buffer_store_if<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -490,21 +567,20 @@ struct buffer_store_if<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -519,21 +595,20 @@ struct buffer_store_if<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -CK_TILE_DEVICE void async_buffer_load_dword(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); + else + asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, @@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, using type = thread_buffer; if constexpr(oob_conditional_check) { - buffer_load_if{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load_if{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } else { - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } } template + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0) + index_t src_immediate_addr_offset = 0, + bool_constant = {}) { static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, - index_t is_valid_element = 0) + index_t is_valid_element = 0, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + index_t is_valid_element = 0, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, // buffer_load OOB still working. template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_element_space_size) + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); @@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); amd_async_buffer_load_impl( - smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); } // buffer_store requires: diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 4a69f67ae3..65a3a4e2ff 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() " ::); } -CK_TILE_DEVICE void s_nop() +CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 - asm volatile("\ - s_nop 0 \n \ - " ::); + asm volatile("s_nop %0" : : "n"(cnt) :); #else - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(cnt); #endif } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 83637e18e4..fa28aa2be9 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -21,6 +21,7 @@ #define __gfx12__ #endif +#include "hip/hip_version.h" #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -147,6 +148,14 @@ #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #endif +#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE +#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091 +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ffe8f7a4fd..ed705c91e7 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -69,6 +69,8 @@ struct buffer_view invalid_element_value_ = T{0}; CK_TILE_HOST_DEVICE constexpr buffer_view() - : p_data_{}, buffer_size_{}, invalid_element_value_{} + : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + : p_data_{p_data}, + buffer_size_{buffer_size}, + cached_buf_res_{0}, + invalid_element_value_{invalid_element_value} { } + // this is non constexpr intentially (will call some intrinsic internally) + // Must call for buffers that need *_raw load/store + CK_TILE_HOST_DEVICE void init_raw() + { + cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type)); + } + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() { return address_space_enum::global; @@ -333,12 +346,15 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t i, + bool is_valid_element, + bool_constant = {}) const { constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -349,18 +365,21 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - dst, p_data_, i, buffer_size_, is_valid_element); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check, pre_nop>( + dst, cached_buf_res_, i, is_valid_element, bool_constant{}); } // i is offset of T, not X. i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, + index_t i, + bool /*is_valid_element*/, + bool_constant = {}) const { // X is vector of T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -371,8 +390,8 @@ struct buffer_view, t_per_x, Coherence>( - smem, p_data_, i, buffer_size_); + amd_async_buffer_load_with_oob_raw, t_per_x, Coherence>( + smem, cached_buf_res_, i, bool_constant{}); } // i is offset of T, not X. i should be aligned to X @@ -627,6 +646,8 @@ struct buffer_view + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto load_tile_raw(T& tile, const tile_window_with_static_distribution& tile_window, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}); + tile_window.load_raw(tile, bool_constant{}, bool_constant{}); } template + index_t NumCoord, + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const tile_window_with_static_distribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}, + bool_constant = {}) { - return tile_window.async_load(lds_tile); + return tile_window.async_load_raw( + lds_tile, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 89806203ab..9707f2990a 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -35,6 +35,8 @@ struct null_tile_window CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + CK_TILE_DEVICE void init_raw() {} + WindowLengths window_lengths_; }; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 656309532e..4655eec241 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -36,6 +36,8 @@ struct tensor_view { } + CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); } + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() @@ -85,30 +87,34 @@ struct tensor_view // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE void - get_vectorized_elements_raw(remove_cvref_t& dst, - const TensorCoord& coord, - bool_constant = {}) const + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}, + bool_constant = {}) const { - return buf_.template get_raw( + return buf_.template get_raw( dst, coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); } template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, - const TensorCoord& coord) const + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( + remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const { - return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + return buf_.template async_get_raw( + smem, coord.get_offset(), true /*not used*/, bool_constant{}); } // X is vector of DataType. diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 5fecd19dcd..79018b9ced 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // sub-dword tensor... -template -CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +template +CK_TILE_DEVICE void +set_tile(DstrTensors& dstr_tensor, number, bool_constant = {}) { - constexpr index_t tensor_bytes = - DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); - if constexpr(v == 0 && tensor_bytes % 4 == 0) + using elem_type = typename DstrTensors::DataType; + constexpr index_t elem_size = sizeof(elem_type); + + constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size; + + // # bytes per write = 4 + if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt) { +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + auto& buffer = dstr_tensor.get_thread_buffer(); + + static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) { + if constexpr(elem_size == 1) + { + // # elements per write = 4 + constexpr auto values = ext_vector_t{0, 0, 0, 0}; + + buffer[i_write * 4 + 0] = values.x; + buffer[i_write * 4 + 1] = values.y; + buffer[i_write * 4 + 2] = values.z; + buffer[i_write * 4 + 3] = values.w; + } + else if constexpr(elem_size == 2) + { + // # elements per write = 2 + constexpr auto values = ext_vector_t{0, 0}; + + buffer[i_write * 2 + 0] = values.x; + buffer[i_write * 2 + 1] = values.y; + } + else if constexpr(elem_size == 4) + { + // # elements per write = 1 + constexpr elem_type value = 0; + + buffer[i_write] = value; + } + else + { + static_assert(false, "type not supported"); + } + }); +#else using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; +#endif } else { - tile_elementwise_inout( - [](auto& x) { x = type_convert(v); }, - dstr_tensor); + tile_elementwise_inout([](auto& x) { x = type_convert(v); }, + dstr_tensor); } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2c38c6aa2c..70f381db1f 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -344,9 +344,10 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, - bool_constant = {}) const + bool_constant = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -373,7 +374,13 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); @@ -384,7 +391,8 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, - bool_constant{}); + bool_constant{}, + pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -399,12 +407,17 @@ struct tile_window_with_static_distribution } }); }); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile("; this inline asm is workaround to prevent compiler from using too much " + "scratch memory" ::); +#endif } // TODO: currently async load only implemented in inline asm - template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - bool_constant = {}) const + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + bool_constant = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -449,11 +462,17 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord); + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -668,6 +687,67 @@ struct tile_window_with_static_distribution }); } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e9a14ca5ac..8251627e6c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync return Problem::kBlockPerCu; else { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + if constexpr(kK0BlockLength <= 32) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && @@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm // a side effect is we have to use inline asm for q as well @@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); + k_dram_window, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync { // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index ddbd16ad9a..493b992aca 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) first = false; else os << delim; - os << static_cast(v); + + if constexpr(std::is_same_v || std::is_same_v) + { + os << ck::type_convert(v); + } + else + { + os << static_cast(v); + } } return os; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp new file mode 100644 index 0000000000..bd756eb825 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -0,0 +1,352 @@ +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace profiler { + +template +inline constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + auto pass = true; // return status + + using CShuffleDataType = float; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = PassThrough; + using WeiElementOp = PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + // random scale values + auto scale_in = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_wei = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_out = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // run reference op + if(do_verification) + { + + std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; + std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7) + << get_rtol() << ", " << get_atol() << ")" << std::endl; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c2a9769727..198f49432f 100755 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -57,6 +57,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) endif() @@ -134,6 +135,8 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) endif() if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") diff --git a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp new file mode 100644 index 0000000000..196a2cf3f2 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp @@ -0,0 +1,220 @@ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK = 0, + NHWGC_GKYXC_NHWGK = 1 +}; + +enum struct OutElementOp +{ + ConvScale = 0, + ConvInvScale = 1 +}; + +enum struct ConvDataType +{ + F8_F8_F8 = 0, + BF8_BF8_F8 = 1, + F8_BF8_F8 = 2, + BF8_F8_F8 = 3 +}; + +#define OP_NAME "grouped_conv_fwd_outelementop" +#define OP_DESC "Grouped Convolution Forward+Elementwise Operation" + +static void print_helper_msg() +{ + // clang-format off + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp8, Weight fp8, Output fp8\n" + << " 1: Input bf8, Weight bf8, Output fp8\n" + << " 2: Input fp8, Weight bf8, Output fp8\n" + << " 3: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: element-wise operation (0: ConvScale\n" + << " 1: ConvInvScale)\n" + << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_outelementop(int argc, char* argv[]) +{ + + // 9 total, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto op = static_cast(std::stoi(argv[3])); + const auto layout = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0] + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using ConvScale = ck::tensor_operation::element_wise::ConvScale; + using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto out_element_op, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using OutElementOp = decltype(out_element_op); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(op == OutElementOp::ConvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, ConvScale{}, F8{}, BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, ConvScale{}, BF8{}, F8{}); + } + } + else if(op == OutElementOp::ConvInvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvInvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + BF8{}, + F8{}, + ConvInvScale{}, + F8{}, + BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + F8{}, + F8{}, + ConvInvScale{}, + BF8{}, + F8{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_outelementop); diff --git a/script/profile_grouped_conv_fwd_outelementop.sh b/script/profile_grouped_conv_fwd_outelementop.sh new file mode 100755 index 0000000000..ac444a25c2 --- /dev/null +++ b/script/profile_grouped_conv_fwd_outelementop.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +OUTELEMENTOP=$3 +LAYOUT=$4 +VERIFY=$5 +INIT=$6 +LOG=$7 +TIME=$8 + +N=$9 + +####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1 +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1