diff --git a/CMakeLists.txt b/CMakeLists.txt index 19c036e1a5..2d8bbd4019 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -572,8 +572,8 @@ include_directories(BEFORE SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) + # add_compile_options(-Werror) + # add_compile_options(-Weverything) endif() message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index bd03aee924..3105719c91 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -35,7 +35,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 - --optdim 32,64,128,256 + --optdim 32,64,128,256,384,512 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) @@ -109,7 +109,7 @@ if(FMHA_FWD_FAST_EXP2) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() -list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero -fbracket-depth=512) # conditionally enable call to the fwd_splitkv API in fmha_fwd example if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 9e15a822ef..e0e537409f 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -12,23 +12,23 @@ FWD_DTYPE_MAP = { BWD_DTYPE_MAP = { "fp16": "FmhaBwdFp16", - "bf16": "FmhaBwdBf16" + # "bf16": "FmhaBwdBf16" } MASK_IMPL = { "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" + # "simplified" : "ck_tile::SimplifiedGenericAttentionMask" } _MASK_SIMPLIFIED_MAP = { "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", + # "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", } _MASK_MAP = { "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" + # "causal" : "FmhaMasks::CausalMask", + # "generic" : "FmhaMasks::GenericMask" } def get_mask_map(mask : str): @@ -62,8 +62,8 @@ def get_mask_check_map(mask : str): BIAS_MAP = { "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" + # "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + # "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" } # TODO: this is ugly @@ -75,10 +75,10 @@ BIAS_CHECK_MAP = { DROPOUT_MAP = { "no" : "ck_tile::BlockDropoutBwd", - "dropout_wg32" : "ck_tile::BlockDropoutBwd", - "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", - "dropout_wg16" : "ck_tile::BlockDropoutBwd", - "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" + # "dropout_wg32" : "ck_tile::BlockDropoutBwd", + # "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + # "dropout_wg16" : "ck_tile::BlockDropoutBwd", + # "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" } DROPOUT_CHECK_MAP = { @@ -103,7 +103,7 @@ ROPE_CHECK_MAP = { MODE_MAP = { "batch" : "false", - "group" : "true" + # "group" : "true" } LAYOUT_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 77b63a0c83..b2e657b4cc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -351,15 +351,17 @@ class FmhaBwdDQDKDVKernel: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), # "kr_ktr_vr_iglp", "kr_ktr_vr"], '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '512' : [FmhaBwdDQDKDVTileSize( 16, 64, 512, 16, 512, 16, 32, 512, 512, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"] } else: @@ -748,6 +750,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): continue if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index c88b058d32..3e8e1fc3ba 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -94,24 +94,1098 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP typename BiasGradDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - const KDramBlockWindowTmp& k_dram_block_window_tmp, - const VDramBlockWindowTmp& v_dram_block_window_tmp, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, - const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - const OGradDramBlockWindowTmp& do_dram_block_window_tmp, - const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, - const DDramBlockWindowTmp& d_dram_block_window_tmp, - const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, - const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, - FmhaMask mask, - PositionEncoding position_encoding, - float raw_scale, - float scale, - float rp_undrop, - float scale_rp_undrop, - void* smem_ptr, - FmhaDropout& dropout) const + dq(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return; + } + } + KDataType* k_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + + auto k_reg_tensor = make_static_distributed_tensor( + Policy::template MakeKRegBlockDescriptor()); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + + VDataType* v_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + + //------------------------------------------------------------------ + // KT, Reg ->LDS ->Reg + // KT reuse K LDS buffer + auto shuffled_k_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledKRegWriteBlockDescriptor()); + + KDataType* kt_lds_ptr = k_lds_ptr; + + auto shuffled_k_lds_write = make_tensor_view( + kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); + + auto shuffled_k_lds_write_window = make_tile_window( + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto kt_lds_read = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); + + auto kt_lds_read_window = + make_tile_window(kt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_block_tile = load_tile(k_dram_window); + store_tile(k_lds_write_window, k_block_tile); + block_sync_lds(); + k_reg_tensor = load_tile(k_lds_read_window); + + shuffle_tile(shuffled_k_block_tile, k_block_tile); + block_sync_lds(); + store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile); + block_sync_lds(); + auto kt_reg_tensor = load_tile(kt_lds_read_window); + + __builtin_amdgcn_sched_barrier(0); + + auto v_block_tile = load_tile(v_dram_window); + block_sync_lds(); + store_tile(v_lds_write_window, v_block_tile); + block_sync_lds(); + auto v_reg_tensor = load_tile(v_lds_read_window); + + //---------------------------- Loop Load in ----------------------------// + // Q: HBM ->Reg ->LDS + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); + + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr))); + + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_lds_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + + // dO: HBM ->Reg ->LDS + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read_window = + make_tile_window(do_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + do_lds_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD() + + Policy::template GetSmemSizeBias())); + + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + auto ds_lds_read_window = + make_tile_window(ds_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + ds_lds_window.get_window_origin(), + Policy::template MakeSGradRegSliceBlockDescriptor()); + + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + BiasDataType* bias_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); + + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_s_lds_read_window = + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + LSEDataType* lse_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + DDataType* d_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeLSE())); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + __builtin_amdgcn_sched_barrier(0); + + // Hot loop + while(i_total_loops < num_total_loop) + { + block_sync_lds(); + + // STAGE 1, Q@K Gemm0 + auto s_acc = SPBlockTileType{}; + auto q_block_tile = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + auto q_reg_tensor = load_tile(q_lds_read_window); + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + __builtin_amdgcn_sched_barrier(0); + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + auto lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + store_tile(lse_lds_write_window, lse_block_tile); + block_sync_lds(); + auto lse = load_tile(lse_lds_read_window); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto p = SPBlockTileType{}; + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + } + else + { + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + } + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 4, OGrad@V Gemm2 + auto do_block_tile = load_tile(do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + auto do_reg_tensor = load_tile(do_lds_read_window); + auto dp_acc = SPGradBlockTileType{}; + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + __builtin_amdgcn_sched_barrier(0); + + // STAGE 5, P^T(PGrad^T - D) + auto d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + store_tile(d_lds_write_window, d_block_tile); + block_sync_lds(); + auto d = load_tile(d_lds_read_window); + auto ds = SPGradBlockTileType{}; + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + // STAGE7 SGrad@K^T Gemm4 + const auto ds_gemm = cast_tile(ds); + store_tile(ds_lds_window, ds_gemm); + block_sync_lds(); + + auto ds_reg_tensor = load_tile(ds_lds_read_window); + auto ds_reg_tensor_next = decltype(ds_reg_tensor){}; + move_tile_window(ds_lds_read_window, {0, kK4}); + __builtin_amdgcn_sched_barrier(0); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {0, kK4}); + } + auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {0, -kN0}); + + // QGrad scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + + // update + i_total_loops += 1; + seqlen_q_step += kM0; + } + } + + template + CK_TILE_HOST_DEVICE auto + dk_dv(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return make_tuple(dk_acc, dv_acc); + } + } + KDataType* k_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + + auto k_reg_tensor = make_static_distributed_tensor( + Policy::template MakeKRegBlockDescriptor()); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + + VDataType* v_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_block_tile = load_tile(k_dram_window); + auto v_block_tile = load_tile(v_dram_window); + store_tile(k_lds_write_window, k_block_tile); + block_sync_lds(); + k_reg_tensor = load_tile(k_lds_read_window); + block_sync_lds(); + store_tile(v_lds_write_window, v_block_tile); + block_sync_lds(); + auto v_reg_tensor = load_tile(v_lds_read_window); + + //---------------------------- Loop Load in ----------------------------// + // Q: HBM ->Reg ->LDS + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); + + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr))); + + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_lds_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + + // QT: Reg -> Reg-> LDS + QDataType* qt_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr) + + Policy::template GetSmemSizeQ())); + + auto shuffled_q_lds_write = make_tensor_view( + qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); + + auto shuffled_q_lds_write_window = make_tile_window( + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto qt_lds_read = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); + + auto qt_lds_read_window = + make_tile_window(qt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->Reg ->LDS + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQT())); + + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read_window = + make_tile_window(do_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + do_lds_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + // dOT: Reg ->Reg ->LDS + // dot reuse do lds memory + OGradDataType* dot_lds_ptr = static_cast(do_lds_ptr); + + auto shuffled_do_lds_write = make_tensor_view( + dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); + + auto shuffled_do_lds_write_window = make_tile_window( + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto dot_read_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); + + auto dot_lds_read_window = + make_tile_window(dot_read_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + BiasDataType* bias_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); + + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_s_lds_read_window = + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + LSEDataType* lse_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad())); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + DDataType* d_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeLSE())); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + + // Hot loop + while(i_total_loops < num_total_loop) + { + // STAGE 4, OGrad@V Gemm2 + auto do_block_tile = load_tile(do_dram_window); + store_tile(do_lds_window, do_block_tile); + move_tile_window(do_dram_window, {kM0, 0}); + block_sync_lds(); + auto do_reg_tensor = load_tile(do_lds_read_window); + auto dp_acc = SPGradBlockTileType{}; + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + __builtin_amdgcn_sched_barrier(0); + + auto shuffled_do_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledOGradRegWriteBlockDescriptor()); + shuffle_tile(shuffled_do_block_tile, do_block_tile); + block_sync_lds(); + store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile); + + // __builtin_amdgcn_sched_barrier(0); // TODO(need opt) + + // STAGE 1, Q@K Gemm0 + auto s_acc = SPBlockTileType{}; + auto q_block_tile = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + store_tile(q_lds_window, q_block_tile); + + auto shuffled_q_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledQRegWriteBlockDescriptor()); + shuffle_tile(shuffled_q_block_tile, q_block_tile); + block_sync_lds(); + store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile); + auto q_reg_tensor = load_tile(q_lds_read_window); + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + __builtin_amdgcn_sched_barrier(0); + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + auto lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + store_tile(lse_lds_write_window, lse_block_tile); + // block_sync_lds(); + // auto lse = load_tile(lse_lds_read_window); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + if constexpr(BiasEnum != BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + } + auto lse = load_tile(lse_lds_read_window); + auto p = SPBlockTileType{}; + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + } + else + { + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + } + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 5, P^T(PGrad^T - D) + auto d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + store_tile(d_lds_write_window, d_block_tile); + block_sync_lds(); + auto d = load_tile(d_lds_read_window); + auto ds = SPGradBlockTileType{}; + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + auto qt_reg_tensor = load_tile(qt_lds_read_window); + + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + __builtin_amdgcn_sched_barrier(0); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto dot_reg_tensor = load_tile(dot_lds_read_window); + auto pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + + // update + i_total_loops += 1; + seqlen_q_step += kM0; + } + + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + return make_tuple(dk_acc, dv_acc); + } + + template + CK_TILE_HOST_DEVICE auto + fused_dqdkdv(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const { static_assert( std::is_same_v> && @@ -1046,6 +2120,87 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return make_tuple(dk_acc, dv_acc); } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + if constexpr(kQKHeaddim > 256 && kVHeaddim > 256) + { + dq(q_dram_block_window_tmp, k_dram_block_window_tmp, v_dram_block_window_tmp, + bias_dram_block_window_tmp, randval_dram_block_window_tmp, do_dram_block_window_tmp, + lse_dram_block_window_tmp, d_dram_block_window_tmp, dq_dram_block_window_tmp, + dbias_dram_block_window_tmp, mask, position_encoding, raw_scale, scale, rp_undrop, + scale_rp_undrop, smem_ptr, dropout); + + block_sync_lds(); + + return dk_dv(q_dram_block_window_tmp, k_dram_block_window_tmp, v_dram_block_window_tmp, + bias_dram_block_window_tmp, randval_dram_block_window_tmp, do_dram_block_window_tmp, + lse_dram_block_window_tmp, d_dram_block_window_tmp, dq_dram_block_window_tmp, + dbias_dram_block_window_tmp, mask, position_encoding, raw_scale, scale, rp_undrop, + scale_rp_undrop, smem_ptr, dropout); + } + else + { + return fused_dqdkdv(q_dram_block_window_tmp, k_dram_block_window_tmp, v_dram_block_window_tmp, + bias_dram_block_window_tmp, randval_dram_block_window_tmp, do_dram_block_window_tmp, + lse_dram_block_window_tmp, d_dram_block_window_tmp, dq_dram_block_window_tmp, + dbias_dram_block_window_tmp, mask, position_encoding, raw_scale, scale, rp_undrop, + scale_rp_undrop, smem_ptr, dropout); + } + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 521968a43b..48d6cc67c6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -712,7 +712,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kKPerBlock = Problem::kQKHeaddim; - constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K1 = 32 / sizeof(AccDataType); constexpr index_t K0 = kKPerBlock / K1; constexpr index_t M2 = get_warp_size() / K0; @@ -1930,13 +1930,44 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_ds = GetSmemSizeSGrad(); constexpr index_t smem_size_bias = GetSmemSizeBias(); - constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; - constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + - smem_size_do + smem_size_lse + smem_size_d + - max(smem_size_bias, smem_size_ds); + if constexpr (Problem::BlockFmhaShape::kQKHeaddim > 256 && Problem::BlockFmhaShape::kVHeaddim > 256) + { + // kernel0: dq + // LDS layout + // | leading stage | leading stage | loop stage + // | K(K^T) | V | Q + // | | | OGrad + // | | | LSE + // | | | D + // | | | Bias + // | | | SGrad + // kernel1: dk & dv + // LDS layout + // | leading stage | leading stage | loop stage + // | K | V | Q + // | | | Q^T + // | | | OGrad(OGrad^T) + // | | | LSE + // | | | D + // | | | Bias + // + // Note: + // A(B) mean A and B use same LDS - return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1); + constexpr index_t smem_size_kernel0 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias + smem_size_ds); + constexpr index_t smem_size_kernel1 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_qt + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias); + return max(smem_size_kernel0, smem_size_kernel1); + } + else + { + constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; + constexpr index_t smem_size_stage0_1 = smem_size_v; + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + smem_size_do + smem_size_lse + smem_size_d + + max(smem_size_bias, smem_size_ds); + + return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1); + } } template