mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Use NRepetitions2DEpilogue for outputting o_acc tile
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
#include "hstu_attention_epilogue.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
@@ -59,33 +60,35 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
BOOL_SWITCH_3(
|
||||
pad_seqlen_k,
|
||||
kPadSeqLenK,
|
||||
pad_headdim_qk,
|
||||
kPadHeadDimQK,
|
||||
pad_headdim_v,
|
||||
kPadHeadDimV,
|
||||
[&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
BOOL_SWITCH_3(pad_seqlen_k,
|
||||
kPadSeqLenK,
|
||||
pad_headdim_qk,
|
||||
kPadHeadDimQK,
|
||||
pad_headdim_v,
|
||||
kPadHeadDimV,
|
||||
[&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
using HstuEpilogue =
|
||||
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
using HstuPipeline =
|
||||
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
});
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
});
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
|
||||
118
example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp
Normal file
118
example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp
Normal file
@@ -0,0 +1,118 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct NRepetitions2DEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
index_t NumNRepetition,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
number<NumNRepetition>)
|
||||
{
|
||||
constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(kN % NumNRepetition == 0, "Check failed!");
|
||||
|
||||
constexpr index_t kSingleRepN = kN / NumNRepetition;
|
||||
|
||||
auto o_nrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM>{}, number<kSingleRepN>{}),
|
||||
o_dram_window_tmp.get_window_origin());
|
||||
|
||||
static_for<0, NumNRepetition, 1>{}([&](auto i_rep) {
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
auto tile_for_store =
|
||||
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
|
||||
sequence<0, i_rep * kSingleRepN>{},
|
||||
sequence<kM, (i_rep + 1) * kSingleRepN>{}));
|
||||
store_tile(o_nrep_dram_window, tile_for_store);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto tile_for_store =
|
||||
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
|
||||
sequence<0, i_rep * kSingleRepN>{},
|
||||
sequence<kM, (i_rep + 1) * kSingleRepN>{}));
|
||||
update_tile(o_nrep_dram_window, tile_for_store);
|
||||
}
|
||||
|
||||
move_tile_window(o_nrep_dram_window, {0, kSingleRepN});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct MRepetitions2DEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
index_t NumMRepetition,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
number<NumMRepetition>)
|
||||
{
|
||||
constexpr index_t kM = ODramWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t kN = ODramWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(kM % NumMRepetition == 0, "Check failed!");
|
||||
|
||||
constexpr index_t kSingleRepM = kM / NumMRepetition;
|
||||
|
||||
auto o_mrep_dram_window = make_tile_window(o_dram_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kSingleRepM>{}, number<kN>{}),
|
||||
o_dram_window_tmp.get_window_origin());
|
||||
|
||||
static_for<0, NumMRepetition, 1>{}([&](auto i_rep) {
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
auto tile_for_store =
|
||||
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
|
||||
sequence<i_rep * kSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kSingleRepM, kN>{}));
|
||||
store_tile(o_mrep_dram_window, tile_for_store);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto tile_for_store =
|
||||
cast_tile<ODataType>(get_slice_tile(o_acc_tile,
|
||||
sequence<i_rep * kSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kSingleRepM, kN>{}));
|
||||
store_tile(o_mrep_dram_window, tile_for_store);
|
||||
}
|
||||
|
||||
move_tile_window(o_mrep_dram_window, {kSingleRepM, 0});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -765,7 +765,9 @@ struct HstuAttentionFwdKernel
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
constexpr index_t NumRepN =
|
||||
HstuAttentionPipeline::kN1 / HstuAttentionPipeline::kGemm1SingleRepN;
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, number<NumRepN>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -71,6 +71,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
|
||||
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
|
||||
@@ -549,6 +549,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN()
|
||||
{
|
||||
return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) *
|
||||
Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
#include "hstu_attention_epilogue.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
@@ -69,7 +70,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
|
||||
Reference in New Issue
Block a user