Use NRepetitions2DEpilogue for outputing o_acc tile

This commit is contained in:
Qianfeng Zhang
2025-05-26 14:09:55 +00:00
parent 81f7b139e0
commit dc0977faad
6 changed files with 160 additions and 25 deletions

View File

@@ -22,6 +22,7 @@
#include "hstu_attention_traits.hpp"
#include "hstu_attention_fwd_pipeline.hpp"
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
template <typename InOutDataType,
bool kUseCausal,
@@ -59,33 +60,35 @@ struct batched_forward_causal_local_bias_dropout_dispatch
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
BOOL_SWITCH_3(
pad_seqlen_k,
kPadSeqLenK,
pad_headdim_qk,
kPadHeadDimQK,
pad_headdim_v,
kPadHeadDimV,
[&] {
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQK,
kPadHeadDimV,
occupancy>;
BOOL_SWITCH_3(pad_seqlen_k,
kPadSeqLenK,
pad_headdim_qk,
kPadHeadDimQK,
pad_headdim_v,
kPadHeadDimV,
[&] {
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQK,
kPadHeadDimV,
occupancy>;
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
kPadSeqLenQ,
kPadHeadDimV>>;
using HstuEpilogue =
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
kPadSeqLenQ,
kPadHeadDimV>>;
using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
using HstuPipeline =
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
using HstuKernel =
ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
RunWithKernel<HstuKernel>(param, stream);
});
RunWithKernel<HstuKernel>(param, stream);
});
};
template <typename HstuKernel>

View File

@@ -0,0 +1,118 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct NRepetitions2DEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename ODramWindowTmp,
typename OAccTile,
index_t NumNRepetition,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
number<NumNRepetition>)
{
constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(kN % NumNRepetition == 0, "Check failed!");
constexpr index_t kSingleRepN = kN / NumNRepetition;
auto o_nrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM>{}, number<kSingleRepN>{}),
o_dram_window_tmp.get_window_origin());
static_for<0, NumNRepetition, 1>{}([&](auto i_rep) {
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
auto tile_for_store =
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
sequence<0, i_rep * kSingleRepN>{},
sequence<kM, (i_rep + 1) * kSingleRepN>{}));
store_tile(o_nrep_dram_window, tile_for_store);
}
else
{
auto tile_for_store =
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
sequence<0, i_rep * kSingleRepN>{},
sequence<kM, (i_rep + 1) * kSingleRepN>{}));
update_tile(o_nrep_dram_window, tile_for_store);
}
move_tile_window(o_nrep_dram_window, {0, kSingleRepN});
});
}
};
template <typename Problem_, typename Policy_ = void>
struct MRepetitions2DEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename ODramWindowTmp,
typename OAccTile,
index_t NumMRepetition,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
number<NumMRepetition>)
{
constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(kM % NumMRepetition == 0, "Check failed!");
constexpr index_t kSingleRepM = kM / NumMRepetition;
auto o_mrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kSingleRepM>{}, number<kN>{}),
o_dram_window_tmp.get_window_origin());
static_for<0, NumMRepetition, 1>{}([&](auto i_rep) {
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
auto tile_for_store =
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
sequence<i_rep * kSingleRepM, 0>{},
sequence<(i_rep + 1) * kSingleRepM, kN>{}));
store_tile(o_mrep_dram_window, tile_for_store);
}
else
{
auto tile_for_store =
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
sequence<i_rep * kSingleRepM, 0>{},
sequence<(i_rep + 1) * kSingleRepM, kN>{}));
store_tile(o_mrep_dram_window, tile_for_store);
}
move_tile_window(o_mrep_dram_window, {kSingleRepM, 0});
});
}
};
} // namespace ck_tile

View File

@@ -765,7 +765,9 @@ struct HstuAttentionFwdKernel
make_tuple(number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
constexpr index_t NumRepN =
HstuAttentionPipeline::kN1 / HstuAttentionPipeline::kGemm1SingleRepN;
EpiloguePipeline{}(o_dram_window, o_acc_tile, number<NumRepN>{});
}
};

View File

@@ -71,6 +71,10 @@ struct HstuAttentionFwdPipelineQRKSVS
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
// used by NRepetitions2DEpilogue
static constexpr index_t kGemm1SingleRepN =
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;

View File

@@ -549,6 +549,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN()
{
return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) *
Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{

View File

@@ -22,6 +22,7 @@
#include "hstu_attention_traits.hpp"
#include "hstu_attention_fwd_pipeline.hpp"
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
template <typename InOutDataType,
bool kUseCausal,
@@ -69,7 +70,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
kPadSeqLenQ,