mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
CK Tile FA Training kernels (#1286)
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Block-level pre-processing step of the FMHA backward pass: for each row of
/// the block it computes D = sum_j(O[i][j] * dO[i][j]) * p_undrop, i.e. the
/// per-row dot product of the forward output tile O and its gradient tile dO,
/// and stores the resulting 1-D tile D to DRAM. Uses no LDS (GetSmemSize() == 0).
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
    using ODataType     = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
    using DDataType     = remove_cvref_t<typename Problem::DDataType>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kVHeaddim   = Problem::kVHeaddim;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;

    // Vector-access alignment for the O / dO loads; falls back to scalar (1)
    // access whenever the V head dimension is padded, since a padded tail may
    // not be divisible by the policy's preferred vector length.
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();

    /// This kernel is purely register/DRAM based; it requires no LDS.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    /// Computes d = rowsum(o * do) * p_undrop and stores it through
    /// d_dram_block_window_tmp.
    ///
    /// @param o_dram_block_window_tmp   DRAM window over the O tile (read).
    /// @param do_dram_block_window_tmp  DRAM window over the dO tile (read).
    /// @param d_dram_block_window_tmp   DRAM window the result D is stored to.
    /// @param p_undrop                  scale applied to each D element
    ///                                  (presumably the dropout keep
    ///                                  probability — confirm with caller).
    template <typename ODramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename DDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
                                        const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
                                        DDramBlockWindowTmp& d_dram_block_window_tmp,
                                        float p_undrop) const
    {
        // Window element types must match the Problem's declared data types.
        static_assert(
            std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");

        // All three windows must span kBlockSize rows along dimension 0.
        static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kBlockSize ==
                              OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        // Re-create the O window with the policy's tile distribution and load O.
        auto o_dram_window =
            make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
                             o_dram_block_window_tmp.get_window_lengths(),
                             o_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePreODramTileDistribution<Problem>());

        auto o = load_tile(o_dram_window);

        // Same for the dO tile.
        auto do_dram_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             do_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePreOGradDramTileDistribution<Problem>());

        auto do_ = load_tile(do_dram_window);

        // declare d: D has O's distribution reduced along dimension 1 (the head
        // dimension), i.e. one accumulator per row.
        constexpr auto d_dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));

        auto d = make_static_distributed_tensor<DDataType>(d_dstr);

        clear_tile(d); // Initialize D

        // Accumulate D[i] += O[i][j] * dO[i][j] over the head dimension j,
        // converting both operands to DDataType before the multiply.
        constexpr auto o_spans = decltype(o)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                d(i_idx) +=
                    (type_convert<DDataType>(o[i_j_idx]) * type_convert<DDataType>(do_[i_j_idx]));
            });
        });

        // Scale every D element by p_undrop before writing it out.
        tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);

        store_tile(d_dram_block_window_tmp, d);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockFmhaBwdOGradDotO. The *LoadOnce_ knobs of
// BlockFmhaBwdPipelineDefaultPolicy are not used by the dO.O kernel; they are
// all set to false here only to satisfy the template parameter list.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,848 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = true;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_kts_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType,
|
||||
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto kt_dram_block_window = kt_dram_block_window_tmp;
|
||||
|
||||
auto kt_dram_window = make_tile_window(
|
||||
kt_dram_block_window.get_bottom_tensor_view(),
|
||||
kt_dram_block_window.get_window_lengths(),
|
||||
kt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
|
||||
// load
|
||||
|
||||
auto kt_block_tile = load_tile(kt_dram_window);
|
||||
|
||||
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
|
||||
|
||||
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the FMHA backward dQ/dK/dV pipeline variant "KSKTSVR"
// (K in Shared, KT in Shared, V in Registers):
// This pipeline is v located in regs, k & k^t located in lds.
//
// The boolean template arguments select which operands are loaded exactly once
// and kept resident for the whole pipeline (K, K^T in LDS; V in registers),
// versus which are marked not-load-once (Q, Q^T, OGrad, OGrad^T) — presumably
// re-streamed from global memory on each outer-loop iteration.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ true,
                                      /* KTLoadOnce_ = */ true,
                                      /* VLoadOnce_ = */ true,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
|
||||
@@ -0,0 +1,821 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// Default policy for the KSVR backward pipeline: V is kept resident in
// registers and K is kept resident in LDS (each loaded once), while Q, Q^T,
// K^T, OGrad and OGrad^T are re-loaded as needed by the pipeline.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
|
||||
@@ -0,0 +1,692 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = true;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "qs_ks_vr_dos";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
}
|
||||
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
block_sync_lds();
|
||||
store_tile(do_lds_window, do_block_tile); // store the prefetch
|
||||
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
get_slice_tile(dot_lds_window,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
get_slice_tile(do_lds_window,
|
||||
sequence<0, i_k2 * kK2>{},
|
||||
sequence<kM0, (i_k2 + 1) * kK2>{}),
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
block_sync_lds();
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
get_slice_tile(qt_lds_window,
|
||||
sequence<0, i_k3 * kK3>{},
|
||||
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is v located in regs, q & k & do located in lds.
|
||||
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ true,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockFmhaBwdPipelineEnum
|
||||
{
|
||||
KSKTSVR = 0,
|
||||
QSKSVROGradS,
|
||||
KSVR,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename GemmDataType_,
|
||||
typename LSEDataType_,
|
||||
typename AccDataType_,
|
||||
typename DDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename ODataType_,
|
||||
typename OGradDataType_,
|
||||
typename QGradDataType_,
|
||||
typename KGradDataType_,
|
||||
typename VGradDataType_,
|
||||
typename BiasGradDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using GemmDataType = remove_cvref_t<GemmDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using DDataType = remove_cvref_t<DDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using OGradDataType = remove_cvref_t<OGradDataType_>;
|
||||
using QGradDataType = remove_cvref_t<QGradDataType_>;
|
||||
using KGradDataType = remove_cvref_t<KGradDataType_>;
|
||||
using VGradDataType = remove_cvref_t<VGradDataType_>;
|
||||
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
typename OGradDataType_,
|
||||
typename DDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kVHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdOGradDotOPipelineProblem
|
||||
{
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using OGradDataType = remove_cvref_t<OGradDataType_>;
|
||||
using DDataType = remove_cvref_t<DDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
|
||||
"kBlockSize should be divisible by get_warp_size()");
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kVHeaddim = kVHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,6 +13,7 @@ template <typename QDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
@@ -23,19 +24,20 @@ template <typename QDataType_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -14,19 +15,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
@@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
@@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
@@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -15,19 +16,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -118,6 +121,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
@@ -137,6 +141,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -145,7 +150,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -292,6 +298,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -558,6 +567,17 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
@@ -688,6 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
@@ -695,11 +716,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -709,6 +732,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
@@ -717,7 +741,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,19 +14,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -106,20 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported
|
||||
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
|
||||
FmhaMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
float descale_qk,
|
||||
float descale_sv,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& /*dropout*/) const // not supported
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -13,19 +13,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
// TODO: assume Q is in register
|
||||
// TODO: assume K/V has same data type
|
||||
@@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
single_smem_size * max(NumPrefetchK, NumPrefetchV);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(AsyncCopyK)
|
||||
{
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
|
||||
{
|
||||
if constexpr(Problem::kHasDropout)
|
||||
{
|
||||
constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto config =
|
||||
decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = WG::kN;
|
||||
|
||||
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
|
||||
@@ -43,4 +43,53 @@ struct TileFmhaShape
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
typename Gemm0BlockWarps_,
|
||||
typename Gemm0WarpTile_,
|
||||
typename Gemm1BlockWarps_,
|
||||
typename Gemm1WarpTile_,
|
||||
typename Gemm2BlockWarps_,
|
||||
typename Gemm2WarpTile_,
|
||||
typename Gemm3BlockWarps_,
|
||||
typename Gemm3WarpTile_,
|
||||
typename Gemm4BlockWarps_,
|
||||
typename Gemm4WarpTile_>
|
||||
struct TileFmhaBwdShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
|
||||
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
using Gemm2BlockWarps = remove_cvref_t<Gemm2BlockWarps_>;
|
||||
using Gemm2WarpTile = remove_cvref_t<Gemm2WarpTile_>;
|
||||
using Gemm3BlockWarps = remove_cvref_t<Gemm3BlockWarps_>;
|
||||
using Gemm3WarpTile = remove_cvref_t<Gemm3WarpTile_>;
|
||||
using Gemm4BlockWarps = remove_cvref_t<Gemm4BlockWarps_>;
|
||||
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 =
|
||||
BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
|
||||
static constexpr index_t kK1 =
|
||||
BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
|
||||
static constexpr index_t kK2 =
|
||||
BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
|
||||
static constexpr index_t kK3 =
|
||||
BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
|
||||
static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or
|
||||
// K/K^T at once
|
||||
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
|
||||
// that need load V at once
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -13,7 +13,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaTraits
|
||||
@@ -23,9 +25,21 @@ struct TileFmhaTraits
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
struct TileFmhaBwdOGradDotOTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user