From dcc3593fe40554d6de8049a3cb7e0437c3f3fc2e Mon Sep 17 00:00:00 2001 From: danyao12 Date: Thu, 25 Jul 2024 16:16:30 +0800 Subject: [PATCH] fix hd32 error and boost performance --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 2 +- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 17 ++- ...block_fmha_bwd_pipeline_default_policy.hpp | 110 ++++++++++++++++++ 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index d4e778877e..dcbc3f61bf 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - # '32' : [FmhaBwdDQDKDVTileSize( 64, 64, 32, 64, 32, 64, 64, 32, 32, 1, 2, 1, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 1), + # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), # "kr_ktr_vr"], '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1), "kr_ktr_vr"], diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index a0e396a5ea..def5b8a013 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }(); // STAGE 3, P^T@OGrad^T Gemm1 - pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer(); + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, pt_gemm); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); auto qt_reg_tensor = load_tile(qt_lds_read_window); @@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // STAGE 6, SGrad^T@Q^T Gemm3 const auto dst_gemm = cast_tile(dst); - dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer(); + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, dst_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); @@ -908,8 +912,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR } }(); - pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer(); - auto dot_reg_tensor = load_tile(dot_lds_read_window); + Policy::template PTFromGemm0CToGemm1A( + pt_reg_tensor, pt_gemm); + auto dot_reg_tensor = load_tile(dot_lds_read_window); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<1>(); @@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // STAGE 6, SGrad^T@Q^T Gemm3 const auto dst_gemm = cast_tile(dst); - dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer(); + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, dst_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); store_tile(ds_lds_window, dst_gemm); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index cde1aa97a1..6387db4ef8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy return ds_block_dstr; } + template + CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out, + const PTInTensor& pt_in) + { + if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16) + { + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + auto pt_warp_tensor = + make_static_distributed_tensor(CWarpDstr{}); + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + pt_warp_tensor.get_thread_buffer() = pt_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + pt_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + pt_warp_tensor.get_thread_buffer()); + }); + }); + } + else + { + pt_out.get_thread_buffer() = pt_in.get_thread_buffer(); + } + } + + template + CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out, + const SGradTInTensor& dst_in) + { + if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16) + { + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + auto dst_warp_tensor = + make_static_distributed_tensor(CWarpDstr{}); + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + dst_warp_tensor.get_thread_buffer() = dst_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + dst_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + dst_warp_tensor.get_thread_buffer()); + }); + }); + } + else + { + dst_out.get_thread_buffer() = dst_in.get_thread_buffer(); + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() {