From a079b95b7787ef1e2c599cac862a00cf65e07995 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 26 May 2025 14:09:55 +0000 Subject: [PATCH] Use NRepetitions2DEpilogue for outputing o_acc tile --- ...stu_attention_batched_forward_dispatch.hpp | 49 ++++---- .../hstu_attention_epilogue.hpp | 118 ++++++++++++++++++ .../hstu_attention_fwd_kernel.hpp | 4 +- .../hstu_attention_fwd_pipeline.hpp | 4 + ..._attention_fwd_pipeline_default_policy.hpp | 7 ++ ...hstu_attention_jagged_forward_dispatch.hpp | 3 +- 6 files changed, 160 insertions(+), 25 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 25e15eb458..5e59736e8d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -22,6 +22,7 @@ #include "hstu_attention_traits.hpp" #include "hstu_attention_fwd_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" +#include "hstu_attention_epilogue.hpp" template ; + BOOL_SWITCH_3(pad_seqlen_k, + kPadSeqLenK, + pad_headdim_qk, + kPadHeadDimQK, + pad_headdim_v, + kPadHeadDimV, + [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; - using HstuPipelineProblem = HstuPipelineProblemTemp; + using HstuPipelineProblem = HstuPipelineProblemTemp; - using HstuEpilogue = ck_tile::Default2DEpilogue::OaccDataType, - typename HstuAttentionFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDimV>>; + using HstuEpilogue = + ck_tile::NRepetitions2DEpilogue::OaccDataType, + typename HstuAttentionFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; - using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS; - using HstuKernel = ck_tile::HstuAttentionFwdKernel; + using HstuPipeline = + ck_tile::HstuAttentionFwdPipelineQRKSVS; + using HstuKernel = + ck_tile::HstuAttentionFwdKernel; - RunWithKernel(param, stream); - }); + RunWithKernel(param, stream); + }); }; template diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp new file mode 100644 index 0000000000..76daebc455 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +namespace ck_tile { + +template +struct NRepetitions2DEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const OAccTile& o_acc_tile, + number) + { + constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(kN % NumNRepetition == 0, "Check failed!"); + + constexpr index_t kSingleRepN = kN / NumNRepetition; + + auto o_nrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + o_dram_window_tmp.get_window_origin()); + + static_for<0, NumNRepetition, 1>{}([&](auto i_rep) { + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + auto tile_for_store = + cast_tile(get_slice_tile(o_acc_tile, + sequence<0, i_rep * kSingleRepN>{}, + sequence{})); + store_tile(o_nrep_dram_window, tile_for_store); + } + else + { + auto tile_for_store = + cast_tile(get_slice_tile(o_acc_tile, + sequence<0, i_rep * kSingleRepN>{}, + sequence{})); + update_tile(o_nrep_dram_window, tile_for_store); + } + + move_tile_window(o_nrep_dram_window, {0, kSingleRepN}); + }); + } +}; + +template +struct MRepetitions2DEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const OAccTile& o_acc_tile, + number) + { + constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(kM % NumMRepetition == 0, "Check failed!"); + + constexpr index_t kSingleRepM = kM / NumMRepetition; + + auto o_mrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + o_dram_window_tmp.get_window_origin()); + + static_for<0, NumMRepetition, 1>{}([&](auto i_rep) { + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + auto tile_for_store = + cast_tile(get_slice_tile(o_acc_tile, + sequence{}, + sequence<(i_rep + 1) * kSingleRepM, kN>{})); + store_tile(o_mrep_dram_window, tile_for_store); + } + else + { + auto tile_for_store = + cast_tile(get_slice_tile(o_acc_tile, + sequence{}, + sequence<(i_rep + 1) * kSingleRepM, kN>{})); + store_tile(o_mrep_dram_window, tile_for_store); + } + + move_tile_window(o_mrep_dram_window, {kSingleRepM, 0}); + }); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index ac1462ae47..47c404742e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -765,7 +765,9 @@ struct HstuAttentionFwdKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + constexpr index_t NumRepN = + HstuAttentionPipeline::kN1 / HstuAttentionPipeline::kGemm1SingleRepN; + EpiloguePipeline{}(o_dram_window, o_acc_tile, number{}); } }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 950e5323f5..0b9ffec66b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -71,6 +71,10 @@ struct HstuAttentionFwdPipelineQRKSVS static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + // used by NRepetitions2DEpilogue + static constexpr index_t kGemm1SingleRepN = + Policy::template GetKVBlockGemmSingleRepN(); + static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::Traits::kBlockPerCu != -1) return Problem::Traits::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 489b377dd2..2c4d3d50b1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -549,6 +549,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN() + { + return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) * + Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + }; + template CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index b97d1b0977..9a093b7663 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -22,6 +22,7 @@ #include "hstu_attention_traits.hpp" #include "hstu_attention_fwd_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" +#include "hstu_attention_epilogue.hpp" template ; - using HstuEpilogue = ck_tile::Default2DEpilogue::OaccDataType, typename HstuAttentionFwdTypeConfig::ODataType, kPadSeqLenQ,