From 7bdf6a7eef84d254cdcea1af01402307c566e6fe Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 22 Aug 2025 03:15:51 +0000 Subject: [PATCH] merge develop and solve conflicts --- .../ck_tile/18_flatmm/run_flatmm_example.inc | 4 +- ...n_grouped_convolution_bwd_data_example.inc | 62 +- .../core/arch/amd_buffer_addressing.hpp | 71 +- .../arch/amd_buffer_addressing_builtins.hpp | 75 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 1123 +++++++++-------- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 8 +- 6 files changed, 675 insertions(+), 668 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 013db6715d..ff1a239cba 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -40,8 +40,8 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) : (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc index 3e1c13c833..d1cf4fade7 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -11,17 +11,17 @@ template float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, - int n_warmup, - int n_repeat) + int n_warmup, + int n_repeat) { float ave_time = grouped_conv_bwd_data( + InDataType, + WeiDataType, + AccDataType, + OutDataType, + InLayout, + WeiLayout, + OutLayout>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -124,11 +124,11 @@ int run_grouped_conv_bwd_data_example_with_layouts( output_dev_buf.ToDevice(output.data()); ck_tile::GroupedConvBwdDataHostArgs args(conv_param, - input_dev_buf.GetDeviceBuffer(), - weight_dev_buf.GetDeviceBuffer(), - {}, - output_dev_buf.GetDeviceBuffer(), - kbatch); + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl; std::cout << "input: " << input.mDesc << std::endl; @@ -136,13 +136,13 @@ int run_grouped_conv_bwd_data_example_with_layouts( std::cout << "output: " << output.mDesc << std::endl; invoke_grouped_conv_bwd_data(args, n_warmup, n_repeat); + InDataType, + WeiDataType, + AccDataType, + OutDataType, + InLayout, + WeiLayout, + OutLayout>(args, n_warmup, n_repeat); input_dev_buf.FromDevice(input.data()); bool pass = true; @@ -152,17 +152,15 @@ int run_grouped_conv_bwd_data_example_with_layouts( ck_tile::HostTensor input_host_ref(in_g_n_c_wis_desc); input_host_ref.SetZero(); - ck_tile:: - reference_grouped_conv_bwd_data( - input_host_ref, - weight, - output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_); - const ck_tile::index_t GemmK = - weight.get_element_size() / (conv_param.G_ * conv_param.K_); + ck_tile::reference_grouped_conv_bwd_data( + input_host_ref, + weight, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); const float max_accumulated_value = *std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end()); const auto rtol_atol = diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index baf42413d0..01efe61387 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1801,18 +1801,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, } _Pragma("clang diagnostic push") -_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") -template -CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - index_t flag = 0, - bool_constant = {}) + _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template < + typename T, + index_t N, + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool oob_conditional_check = true> + CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0, + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; @@ -1835,23 +1835,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, if constexpr(oob_conditional_check) v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - (as3_uint32_ptr)(smem), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } _Pragma("clang diagnostic pop") -template -CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) + template + CK_TILE_DEVICE + void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); @@ -2787,11 +2787,10 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_ #endif } -_Pragma("clang diagnostic push") -_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") +_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") #if defined(__gfx950__) -template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) + template + __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), @@ -2801,8 +2800,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = (__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr); - //reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) @@ -2810,8 +2809,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = (__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr; - //reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t>) @@ -2819,8 +2818,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = (__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr; - //reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 271ca2ee16..1d682f4f5d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1571,18 +1571,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, } _Pragma("clang diagnostic push") -_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") -template -CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - index_t flag = 0, - bool_constant = {}) + _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template < + typename T, + index_t N, + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool oob_conditional_check = true> + CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0, + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; @@ -1605,23 +1605,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, if constexpr(oob_conditional_check) v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - (as3_uint32_ptr)(smem), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } _Pragma("clang diagnostic pop") -template -CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) + template + CK_TILE_DEVICE + void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); @@ -2597,20 +2597,17 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, static_assert(bytes_per_thread == dword_bytes); #endif // LDS pointer must be attributed with the LDS address space. - as3_uint32_ptr lds_ptr = - (as3_uint32_ptr)(lds_base_ptr + lds_offset); + as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_base_ptr + lds_offset); llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } - -_Pragma("clang diagnostic push") -_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") +_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") #if defined(__gfx950__) -template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) + template + __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), @@ -2620,8 +2617,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = (__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr); - //reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) @@ -2629,8 +2626,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = (__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr; - //reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t>) @@ -2638,8 +2635,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = (__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr; - //reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( - // reinterpret_cast(in_ptr)); + // reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + // reinterpret_cast(in_ptr)); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 84eea8d119..41b716ff64 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -187,624 +187,637 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // LDS allocation const auto smem_ptr_ = reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic - auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr, - VDataType* __restrict__ v_lds_ptr, - OGradDataType* __restrict__ do_lds_ptr0, - OGradDataType* __restrict__ do_lds_ptr1, - QDataType* __restrict__ q_lds_ptr0, - QDataType* __restrict__ q_lds_ptr1, - LSEDataType* __restrict__ lse_lds_ptr, - DDataType* __restrict__ d_lds_ptr, - GemmDataType* __restrict__ ds_lds_ptr, - BiasDataType* __restrict__ bias_lds_ptr) - { -/* - const auto k_lds_ptr = reinterpret_cast(smem_ptr_); - const auto v_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeK()); + auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr, + VDataType* __restrict__ v_lds_ptr, + OGradDataType* __restrict__ do_lds_ptr0, + OGradDataType* __restrict__ do_lds_ptr1, + QDataType* __restrict__ q_lds_ptr0, + QDataType* __restrict__ q_lds_ptr1, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, + GemmDataType* __restrict__ ds_lds_ptr, + BiasDataType* __restrict__ bias_lds_ptr) { + /* + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); - const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); - const auto do_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr0 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ()); - const auto lse_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); - const auto d_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE()); - const auto ds_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); - const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); -*/ - auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); - auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); const auto do_lds_ptr1 = reinterpret_cast( smem_ptr_ + Policy::template GetSmemSizeOGrad()); const auto + q_lds_ptr0 = reinterpret_cast( smem_ptr_ + Policy::template + GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad()); const + auto q_lds_ptr1 = reinterpret_cast( smem_ptr_ + + Policy::template GetSmemSizeOGrad() + Policy::template + GetSmemSizeOGrad() + Policy::template GetSmemSizeQ()); const auto + lse_lds_ptr = reinterpret_cast( smem_ptr_ + + Policy::template GetSmemSizeOGrad() + Policy::template + GetSmemSizeOGrad() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ()); const auto d_lds_ptr = + reinterpret_cast( smem_ptr_ + Policy::template + GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template + GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); const auto + ds_lds_ptr = reinterpret_cast( smem_ptr_ + + Policy::template GetSmemSizeOGrad() + Policy::template + GetSmemSizeOGrad() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD()); const auto bias_lds_ptr = + reinterpret_cast(ds_lds_ptr); + */ + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - //------------------------------------------------------------------ - // V, HBM ->LDS ->Reg - auto v_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - v_dram_block_window_tmp.get_bottom_tensor_view()), - v_dram_block_window_tmp.get_window_lengths(), - v_dram_block_window_tmp.get_window_origin(), - Policy::template MakeVDramTileDistribution()); - auto v_lds = make_tensor_view( - v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); - auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + v_dram_block_window_tmp.get_bottom_tensor_view()), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); - //------------------------------------------------------------------ - // KT, HBM -> LDS --trload-->Reg - async_load_tile(k_lds_write_window, k_dram_window); - async_load_tile(v_lds_write_window, v_dram_window); - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); + //------------------------------------------------------------------ + // KT, HBM -> LDS --trload-->Reg + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); - //------------------------------------------------------------------ - // Pre-Load KV into Registers - auto k_lds_read = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor()); - auto k_lds_read_window = - make_tile_window(k_lds_read, - make_tuple(number{}, number{}), - k_lds_write_window.get_window_origin(), - Policy::template MakeKRegBlockDescriptor()); - auto k_reg_tensor = load_tile(k_lds_read_window); + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_lds_read = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor()); + auto k_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + auto k_reg_tensor = load_tile(k_lds_read_window); - auto kt_lds_read_window = - make_tile_window(k_lds_read, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKTRegBlockDescriptor()); + auto kt_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); - auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window); - auto v_lds_read = make_tensor_view( - v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor()); - auto v_lds_read_window = - make_tile_window(v_lds_read, - make_tuple(number{}, number{}), - v_lds_write_window.get_window_origin(), - Policy::template MakeVRegBlockDescriptor()); - auto v_reg_tensor = load_tile(v_lds_read_window); + auto v_lds_read = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor()); + auto v_lds_read_window = + make_tile_window(v_lds_read, + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + auto v_reg_tensor = load_tile(v_lds_read_window); - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); - //---------------------------- Loop Load in ----------------------------// - // Q: HBM -->LDS - auto q_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - q_dram_block_window_tmp.get_bottom_tensor_view()), - q_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}, - Policy::template MakeQDramTileDistribution()); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + //---------------------------- Loop Load in ----------------------------// + // Q: HBM -->LDS + auto q_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + q_dram_block_window_tmp.get_bottom_tensor_view()), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); - auto q_lds = make_tensor_view( - q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor()); - auto q_lds_write_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + auto q_lds = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor()); + auto q_lds_write_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); - auto q_lds_read = make_tensor_view( - q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor()); - auto q_lds_read_window = - make_tile_window(q_lds_read, - make_tuple(number{}, number{}), - q_lds_write_window.get_window_origin(), - Policy::template MakeQRegSliceBlockDescriptor()); - auto qt_lds_read_window = - make_tile_window(q_lds_read, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQTRegSliceBlockDescriptor()); + auto q_lds_read = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor()); + auto q_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + q_lds_write_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + auto qt_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); - // dO: HBM ->LDS ---load--> Reg - // dOT: \-loadtr-> Reg - auto do_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - do_dram_block_window_tmp.get_bottom_tensor_view()), - do_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}, - Policy::template MakeOGradDramTileDistribution()); + // dO: HBM ->LDS ---load--> Reg + // dOT: \-loadtr-> Reg + auto do_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + do_dram_block_window_tmp.get_bottom_tensor_view()), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); - auto do_lds = make_tensor_view( - do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor()); - auto do_lds_write_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + auto do_lds = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor()); + auto do_lds_write_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); - auto do_lds_read = make_tensor_view( - do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor()); - auto do_lds_read_window = - make_tile_window(do_lds_read, - make_tuple(number{}, number{}), - do_lds_write_window.get_window_origin(), - Policy::template MakeOGradRegSliceBlockDescriptor()); - auto dot_lds_read_window = - make_tile_window(do_lds_read, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeOGradTRegSliceBlockDescriptor()); + auto do_lds_read = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor()); + auto do_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + do_lds_write_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + auto dot_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); - // dS: Reg -> Reg -> LDS - auto ds_lds = make_tensor_view( - ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + // dS: Reg -> Reg -> LDS + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); - auto ds_lds_window = - make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); - // transform it to make it from col-major to row-major; prepared for load_tile_transpose - auto ds_lds_t = make_tensor_view( - ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); - auto ds_lds_read_window = - make_tile_window(ds_lds_t, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeSGradRegSliceBlockDescriptor()); + // transform it to make it from col-major to row-major; prepared for load_tile_transpose + auto ds_lds_t = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_read_window = + make_tile_window(ds_lds_t, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeSGradRegSliceBlockDescriptor()); - // Bias: HBM ->Reg ->Reg ->LDS - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, bias_origin.at(number<1>{})}, - Policy::template MakeBiasTileDistribution()); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); - auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); - auto bias_lds_write_window = - make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_s_lds_read_window = - make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), - bias_lds_write_window.get_window_lengths(), - bias_lds_write_window.get_window_origin(), - Policy::template MakeBiasSTileDistribution()); + auto bias_s_lds_read_window = + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); - static_assert(std::is_same_v, - "BiasDataType and BiasGradDataType should be the same!"); + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); - // LSE: HBM -> LDS ->Reg - auto lse_dram_window = make_tile_window( - lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}, - Policy::template MakeLSEDDramTileDistribution()); + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); - auto lse_lds = make_tensor_view( - lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); - auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); - auto lse_lds_read_window = make_tile_window( - lse_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); - // D: HBM ->Reg - auto d_dram_window = make_tile_window( - d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}, - Policy::template MakeLSEDDramTileDistribution()); + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); - auto d_lds = make_tensor_view( - d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); - auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); - auto d_lds_read_window = make_tile_window( - d_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); - // RandVal: HBM ->Reg - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_q_start); + // RandVal: HBM ->Reg + auto randval_dram_window = + dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); - // BiasGrad - // Reg ->LDS ->Reg ->HBM - const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); - auto dbias_dram_window = - make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), - dbias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N - auto dbias_lds_read_window = - make_tile_window(bias_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeShuffledBiasTileDistribution()); + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); - // ----------------------------Loop write out------------------------------// - auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); - index_t i_total_loops = 0; - index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); - static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); - static_assert(kM0 == kK3, "kM0 should equal to kK3"); - constexpr index_t k4_loops = kN0 / kK4; + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; - clear_tile(dv_acc); - clear_tile(dk_acc); + clear_tile(dv_acc); + clear_tile(dk_acc); - __builtin_amdgcn_sched_barrier(0); - - decltype(load_tile(q_lds_read_window)) q_reg_tensor; - decltype(load_tile(lse_lds_read_window)) lse; - decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor; - decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next; - decltype(load_tile(do_lds_read_window)) do_reg_tensor; - decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor; - decltype(load_tile(d_lds_read_window)) d; - decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor; - decltype(gemm_0.MakeCBlockTile()) s_acc, p; - decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; - decltype(gemm_4.MakeCBlockTile()) dq_acc; - - decltype(load_tile(lse_dram_window)) lse_block_tile; - decltype(load_tile(d_dram_window)) d_block_tile; - - index_t i_total_bodys = 0; - auto main_body_impl = [&](auto is_prologue_, - auto is_epilogue_, - QDataType* const __restrict__ q_lds_ptr_curr, - QDataType* const __restrict__ q_lds_ptr_next, - OGradDataType* const __restrict__ do_lds_ptr_curr, - OGradDataType* const __restrict__ do_lds_ptr_next) mutable { - constexpr bool is_prologue = is_prologue_.value; - constexpr bool is_epilogue = is_epilogue_.value; - static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); - constexpr bool is_main_body = is_prologue && is_epilogue; - - if constexpr(is_prologue) - { - lse_block_tile = load_tile(lse_dram_window); - move_tile_window(lse_dram_window, {kM0}); - - d_block_tile = load_tile(d_dram_window); - move_tile_window(d_dram_window, {kM0}); - - q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); - async_load_tile(q_lds_write_window, q_dram_window); - move_tile_window(q_dram_window, {kM0, 0}); - - do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); - async_load_tile(do_lds_write_window, do_dram_window); - move_tile_window(do_dram_window, {kM0, 0}); - } - if constexpr(is_epilogue) - { - // STAGE 1, Q@K Gemm0 - s_acc = gemm_0(q_reg_tensor, k_reg_tensor); - - dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - } - if constexpr(is_main_body) - Policy::template HotLoopScheduler::SchedulerGemm0(); __builtin_amdgcn_sched_barrier(0); - if constexpr(is_epilogue) - { - // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const auto bias_tile = load_tile(bias_dram_window); - auto shuffled_bias_tile = make_static_distributed_tensor( - Policy::template MakeShuffledBiasTileDistribution()); - shuffle_tile(shuffled_bias_tile, bias_tile); - store_tile(bias_lds_write_window, shuffled_bias_tile); - block_sync_lds(); - auto bias_s_tile = load_tile(bias_s_lds_read_window); - tile_elementwise_inout( - [&](auto& x, const auto& y) { - x = scale * x + log2e_v * type_convert(y); - }, - s_acc, - bias_s_tile); - move_tile_window(bias_dram_window, {kM0, 0}); - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = seqlen_q_step + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + decltype(load_tile(q_lds_read_window)) q_reg_tensor; + decltype(load_tile(lse_lds_read_window)) lse; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next; + decltype(load_tile(do_lds_read_window)) do_reg_tensor; + decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor; + decltype(load_tile(d_lds_read_window)) d; + decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor; + decltype(gemm_0.MakeCBlockTile()) s_acc, p; + decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; + decltype(gemm_4.MakeCBlockTile()) dq_acc; + + decltype(load_tile(lse_dram_window)) lse_block_tile; + decltype(load_tile(d_dram_window)) d_block_tile; + + index_t i_total_bodys = 0; + auto main_body_impl = [&](auto is_prologue_, + auto is_epilogue_, + QDataType* const __restrict__ q_lds_ptr_curr, + QDataType* const __restrict__ q_lds_ptr_next, + OGradDataType* const __restrict__ do_lds_ptr_curr, + OGradDataType* const __restrict__ do_lds_ptr_next) mutable { + constexpr bool is_prologue = is_prologue_.value; + constexpr bool is_epilogue = is_epilogue_.value; + static_assert(is_prologue || is_epilogue, + "is_prologue or is_epilogue should be true"); + constexpr bool is_main_body = is_prologue && is_epilogue; + + if constexpr(is_prologue) + { + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + + q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + async_load_tile(q_lds_write_window, q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + + do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + async_load_tile(do_lds_write_window, do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + } + if constexpr(is_epilogue) + { + // STAGE 1, Q@K Gemm0 + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); + dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm0(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - s_acc(i_j_idx) *= scale; - position_encoding.update(s_acc(i_j_idx), row, col); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + else + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); }); }); - } - { - bool need_perpixel_check = mask.IsEdgeTile( - seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); - if(need_perpixel_check) - { - set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = seqlen_q_step + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - constexpr auto p_spans = decltype(p)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); - - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); - else - p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); - }); - }); - - if constexpr(FmhaDropout::IsDropout) - { - dropout.template Run( - seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); - } - const auto p_gemm = [&]() { // dropout / type conversion if constexpr(FmhaDropout::IsDropout) { - return tile_elementwise_in( - [](const auto& x) { - return type_convert(x > 0.f ? x : 0.f); - }, - p); + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); } - else - { - return cast_tile(p); - } - }(); - - // STAGE 4, OGrad@V Gemm2 - dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); - - qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - - // STAGE 3, P^T@OGrad^T Gemm1 - auto pt_reg_tensor = make_static_distributed_tensor( - Policy::template MakePTRegSliceBlockDescriptor()); - pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); - gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); - } - block_sync_lds(); - if constexpr(is_main_body) - Policy::template HotLoopScheduler::SchedulerGemm12(); - __builtin_amdgcn_sched_barrier(0); - if constexpr(is_prologue) - { - store_tile(lse_lds_write_window, lse_block_tile); - store_tile(d_lds_write_window, d_block_tile); - } - if constexpr(is_epilogue) - { - // STAGE 5, P^T(PGrad^T - D) - constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); - sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - bool undrop_flag = p[i_j_idx] >= 0; - ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag - ? (dp_acc[i_j_idx] - d[i_idx]) - : d[i_idx]); - }); - }); - - if constexpr(kHasBiasGrad) - { - const auto dbias = [&]() { + const auto p_gemm = [&]() { // dropout / type conversion if constexpr(FmhaDropout::IsDropout) { return tile_elementwise_in( - [&rp_undrop](const auto& x) { - return type_convert(x * rp_undrop); + [](const auto& x) { + return type_convert(x > 0.f ? x : 0.f); }, - ds); + p); } else { - return cast_tile(ds); + return cast_tile(p); } }(); - store_tile(bias_lds_write_window, dbias); - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); - auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); - auto dbias_tile = make_static_distributed_tensor( - Policy::template MakeBiasTileDistribution()); - shuffle_tile(dbias_tile, shuffled_dbias_tile); - store_tile(dbias_dram_window, dbias_tile); - move_tile_window(dbias_dram_window, {kM0, 0}); - __builtin_amdgcn_sched_barrier(0); - } - } - if constexpr(is_epilogue) - { - // STAGE 6, SGrad^T@Q^T Gemm3 - const auto ds_gemm = cast_tile(ds); - auto dst_reg_tensor = make_static_distributed_tensor( - Policy::template MakeSGradTRegSliceBlockDescriptor()); - dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); - gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - store_tile(ds_lds_window, ds_gemm); - } - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); - if constexpr(is_prologue) - { - q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); - q_reg_tensor = load_tile(q_lds_read_window); - lse = load_tile(lse_lds_read_window); - } - if constexpr(is_epilogue) - { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); - move_tile_window(ds_lds_read_window, {kK4, 0}); - } - if constexpr(is_main_body) - Policy::template HotLoopScheduler::SchedulerGemm3(); - __builtin_amdgcn_sched_barrier(0); - if constexpr(is_epilogue) - { - // STAGE7 SGrad@K^T Gemm4 - clear_tile(dq_acc); - static_for<0, k4_loops, 1>{}([&](auto i_k4) { - if constexpr(i_k4 < k4_loops - 1) - { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); - move_tile_window(ds_lds_read_window, {kK4, 0}); - } - auto kt_reg_tensor_slice = get_slice_tile( // - kt_reg_tensor, - sequence<0, i_k4 * kK4>{}, - sequence{}); - gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + // STAGE 4, OGrad@V Gemm2 + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); - if constexpr(i_k4 < k4_loops - 1) + qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); + qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + } + block_sync_lds(); + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm12(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_prologue) + { + store_tile(lse_lds_write_window, lse_block_tile); + store_tile(d_lds_write_window, d_block_tile); + } + if constexpr(is_epilogue) + { + // STAGE 5, P^T(PGrad^T - D) + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) { - ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); } - }); - move_tile_window(ds_lds_read_window, {-kN0, 0}); - } - block_sync_lds(); - if constexpr(is_prologue) + } + if constexpr(is_epilogue) + { + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + } + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + if constexpr(is_prologue) + { + q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + q_reg_tensor = load_tile(q_lds_read_window); + lse = load_tile(lse_lds_read_window); + } + if constexpr(is_epilogue) + { + ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm3(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE7 SGrad@K^T Gemm4 + clear_tile(dq_acc); + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + auto kt_reg_tensor_slice = get_slice_tile( // + kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = + ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {-kN0, 0}); + } + block_sync_lds(); + if constexpr(is_prologue) + { + do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + do_reg_tensor = load_tile(do_lds_read_window); + d = load_tile(d_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm4(); + if constexpr(is_epilogue) + { + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout( + [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, + dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + } + }; + + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + const bool is_even = (i_total_bodys % 2 == 0); + const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; + const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; + main_body_impl(is_prologue_, + is_epilogue_, + q_lds_ptr_curr, + q_lds_ptr_next, + do_lds_ptr_curr, + do_lds_ptr_next); + i_total_bodys += 1; + }; + + main_body(std::true_type{}, std::false_type{}); + // Hot loop + if(num_total_loop > 1) { - do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); - do_reg_tensor = load_tile(do_lds_read_window); - d = load_tile(d_lds_read_window); + do + { + main_body(std::true_type{}, std::true_type{}); + i_total_loops += 1; + seqlen_q_step += kM0; + } while(i_total_loops < num_total_loop - 1); } - if constexpr(is_main_body) - Policy::template HotLoopScheduler::SchedulerGemm4(); - if constexpr(is_epilogue) + main_body(std::false_type{}, std::true_type{}); + + // Results Scale + if constexpr(FmhaDropout::IsDropout) { - // QGrad Scale - if constexpr(FmhaDropout::IsDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dq_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); - } - if constexpr(kIsDeterministic) - { - store_tile(dq_dram_window, dq_acc); - } - else - { - update_tile(dq_dram_window, dq_acc); - } - move_tile_window(dq_dram_window, {kM0, 0}); + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); } }; - - auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { - const bool is_even = (i_total_bodys % 2 == 0); - const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; - const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; - const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; - const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; - main_body_impl(is_prologue_, - is_epilogue_, - q_lds_ptr_curr, - q_lds_ptr_next, - do_lds_ptr_curr, - do_lds_ptr_next); - i_total_bodys += 1; - }; - - main_body(std::true_type{}, std::false_type{}); - // Hot loop - if(num_total_loop > 1) - { - do - { - main_body(std::true_type{}, std::true_type{}); - i_total_loops += 1; - seqlen_q_step += kM0; - } while(i_total_loops < num_total_loop - 1); - } - main_body(std::false_type{}, std::true_type{}); - - // Results Scale - if constexpr(FmhaDropout::IsDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dk_acc); - tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); - } - - }; - restrict_body(reinterpret_cast(smem_ptr_), // k_lds_ptr - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK()), // v_lds_ptr - reinterpret_cast(smem_ptr_), // do_lds_ptr0 - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeOGrad()), // do_lds_ptr1 - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeOGrad() - + Policy::template GetSmemSizeOGrad()), // q_lds_ptr0 - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeOGrad() - + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeQ()), // q_lds_ptr1 - reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()), // lse_lds_ptr - reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE()), // d_lds_ptr - reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()), // ds_ltr_ptr - reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD())); // bias_ltr_ptr + restrict_body( + reinterpret_cast(smem_ptr_), // k_lds_ptr + reinterpret_cast(smem_ptr_ + + Policy::template GetSmemSizeK()), // v_lds_ptr + reinterpret_cast(smem_ptr_), // do_lds_ptr0 + reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad()), // do_lds_ptr1 + reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad()), // q_lds_ptr0 + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()), // q_lds_ptr1 + reinterpret_cast(smem_ptr_ + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ()), // lse_lds_ptr + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE()), // d_lds_ptr + reinterpret_cast(smem_ptr_ + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD()), // ds_ltr_ptr + reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); // bias_ltr_ptr return make_tuple(dk_acc, dv_acc); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 9ae1d25c97..73901d9bb0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_barrier(0); - auto mainloop = [&] (index_t cur_loop, + auto mainloop = [&](index_t cur_loop, KDataType* __restrict__ k_lds_write_ptr, KDataType* __restrict__ k_lds_read_ptr, KDataType* __restrict__ v_lds_write_ptr, KDataType* __restrict__ v_lds_read_ptr) { - // move V tile windows block_sync_lds(); move_tile_window(v_dram_window, {kN0, 0}); @@ -1108,7 +1107,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload do { - bool is_even_loop = i_total_loops % 2 == 0; + bool is_even_loop = i_total_loops % 2 == 0; auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) : static_cast(smem_ptrk1); auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) @@ -1117,7 +1116,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload : static_cast(smem_ptrv0); auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) : static_cast(smem_ptrv1); - mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr); + mainloop( + i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr); i_total_loops++; } while(i_total_loops < num_total_loop);