diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9c2907778f..9f1e0f6948 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - - printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " - "bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", - fmha_traits.hdim_q, - fmha_traits.hdim_v, - fmha_traits.data_type.c_str(), - fmha_traits.is_group_mode, - static_cast(fmha_traits.mask_type), - static_cast(fmha_traits.bias_type), - fmha_traits.has_dbias, - fmha_traits.has_dropout, - fmha_traits.is_store_randval, - fmha_traits.is_deterministic); - fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07be65a150..037e86909d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1276,26 +1276,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); -template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1766,15 +1786,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1536,15 +1556,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template +CK_TILE_DEVICE index_t get_warp_id(bool_constant = {}) { - return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); + const index_t warp_id = threadIdx.x / get_warp_size(); + if constexpr(ReturnSgpr) + { + return __builtin_amdgcn_readfirstlane(warp_id); + } + else + { + return warp_id; + } } CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index ad5902f16e..f5ddcd278c 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -288,8 +288,11 @@ struct tile_window_with_static_distribution sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + // Use VALU so the compiler can optimize redundant/repeated computations + const index_t m0_init_value = + size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant{}); + m0_set_with_memory( + __builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 81075d0ec6..66f51459af 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel {0, i_n0}); // If kApplyRoPe is false, we set the rotary_dim to 0 - auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0; - + auto rotary_dim = [&]() { + if constexpr(kApplyRoPE) + return kargs.rotary_dim; + else + return 0; + }(); FmhaPipeline{}(q_dram_window, k_dram_window, i_page_block_k, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 1d95bc2801..9a31498dd1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {seqlen_q_start, bias_origin.at(number<1>{})}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 65f70c4f62..3112070271 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {bias_origin.at(number<0>{}), seqlen_kv_start}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 7849c931f7..6259e5b473 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor() { - return MakeXLdsWriteBlockDescriptor(); + return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor(); } template @@ -684,13 +682,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kQKHeaddim>(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor() - { - return MakeXLdsReadBlockDescriptor(); - } template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() @@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t N1 = min(static_cast(GetAlignmentBias()), - kMPerBlock * kNPerBlock / kBlockSize); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = get_warp_size() / N0; - constexpr index_t M2 = kMPerBlock / M1 / M0; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution(); } template @@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return sizeof(typename Problem::BiasDataType) * - MakeBiasLdsWriteBlockDescriptor().get_element_space_size(); + MakeBiasLdsBlockDescriptor().get_element_space_size(); else return 0; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index c492ce6827..ff1f31edc8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -364,7 +364,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; if constexpr(AsyncCopy) { - return 4 / sizeof(KDataType); +#if defined(__gfx950__) + constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4 +#else + constexpr index_t MaxLoadSizeInBytes = 4; // dword +#endif + + return MaxLoadSizeInBytes / sizeof(KDataType); } else {