mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@@ -26,14 +27,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
|
||||
template <typename FmhaPipeline_,
|
||||
typename KGradEpiloguePipeline_,
|
||||
typename VGradEpiloguePipeline_,
|
||||
typename QGradEpiloguePipeline_ = void>
|
||||
struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
|
||||
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
|
||||
using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static constexpr bool kUseQrQtrDorPipeline =
|
||||
ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c<FmhaPipeline>;
|
||||
static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
|
||||
"QrQtrDorPipeline needs QGradEpiloguePipeline");
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
@@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
|
||||
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
@@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
const void* d_ptr;
|
||||
void* dq_acc_ptr;
|
||||
void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
|
||||
@@ -335,7 +346,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -482,7 +493,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
@@ -640,7 +651,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
return dim3(
|
||||
ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
|
||||
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
@@ -735,10 +748,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -786,12 +798,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
|
||||
batch_offset_dk;
|
||||
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
|
||||
batch_offset_dv;
|
||||
auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
|
||||
auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
|
||||
|
||||
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -868,8 +878,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
|
||||
|
||||
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kUseKSplit)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
@@ -878,7 +891,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
|
||||
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
@@ -1063,25 +1076,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
auto dk_dram = [&]() {
|
||||
const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dk_ptr,
|
||||
@@ -1119,9 +1113,56 @@ struct FmhaBwdDQDKDVKernel
|
||||
dv_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
{i_n0, 0});
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
{
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dk_dram_window,
|
||||
dv_dram_window,
|
||||
dbias_dram_window,
|
||||
QGradEpiloguePipeline{},
|
||||
KGradEpiloguePipeline{},
|
||||
VGradEpiloguePipeline{},
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,12 +15,15 @@ template <typename Problem, typename Policy>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
|
||||
|
||||
public:
|
||||
template <typename... TS>
|
||||
using type_ =
|
||||
std::conditional_t<Problem::kUseTrLoad,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>,
|
||||
std::conditional_t<is_decode,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
|
||||
std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
|
||||
|
||||
@@ -0,0 +1,743 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
{
|
||||
static constexpr auto is_qr_qtr_dor_pipeline = true;
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
|
||||
// using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(kUseTrLoad, "This pipeline uses trload!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
|
||||
return (raw_lse == -numeric<LSEDataType>::infinity()) //
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
else
|
||||
return raw_lse;
|
||||
};
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename KGradDramBlockWindowTmp,
|
||||
typename VGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename QGradEpilogue,
|
||||
typename KGradEpilogue,
|
||||
typename VGradEpilogue,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const KGradDramBlockWindowTmp& dk_dram_block_window_tmp,
|
||||
const VGradDramBlockWindowTmp& dv_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
const QGradEpilogue& dq_epilogue,
|
||||
const KGradEpilogue& dk_epilogue,
|
||||
const VGradEpilogue& dv_epilogue,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
// Early termination
|
||||
const auto [seqlen_kv_start, seqlen_kv_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_kv_end - seqlen_kv_start, kN0);
|
||||
|
||||
// K, HBM ->LDS ->Reg
|
||||
auto k_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
auto v_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, HBM -> LDS --trload-->Reg
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
auto k_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM -->LDS
|
||||
auto q_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
q_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK0>{}),
|
||||
q_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dO: HBM ->LDS ---load--> Reg
|
||||
// dOT: \-loadtr-> Reg
|
||||
auto do_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
|
||||
do_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>());
|
||||
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
|
||||
auto do_lds_write_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
do_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dS: Reg -> Reg -> LDS
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// transform it to make it from col-major to row-major; prepared for load_tile_transpose
|
||||
auto ds_lds_t = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
|
||||
auto ds_lds_read_window =
|
||||
make_tile_window(ds_lds_t,
|
||||
make_tuple(number<kM0>{}, number<kK4>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// Bias: HBM ->Reg ->Reg ->LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_kv_start},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
|
||||
randval_dram_block_window_tmp, seqlen_kv_start);
|
||||
|
||||
// BiasGrad
|
||||
// Reg ->LDS ->Reg ->HBM
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto dbias_dram_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{dbias_origin.at(number<0>{}), seqlen_kv_start}); // M/N
|
||||
|
||||
auto dbias_lds_read_window =
|
||||
make_tile_window(bias_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
auto dk_dram_window = make_tile_window(dk_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dk_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
auto dv_dram_window = make_tile_window(dv_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dv_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_kv_step = seqlen_kv_start;
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
decltype(load_tile(q_lds_read_window)) q_reg_tensor;
|
||||
decltype(load_tile(lse_lds_read_window)) lse;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
|
||||
decltype(load_tile(do_lds_read_window)) do_reg_tensor;
|
||||
decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
|
||||
decltype(load_tile(d_lds_read_window)) d;
|
||||
decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
|
||||
decltype(gemm_0.MakeCBlockTile()) s_acc, p;
|
||||
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
|
||||
decltype(gemm_4.MakeCBlockTile()) dq_acc;
|
||||
clear_tile(dq_acc);
|
||||
|
||||
decltype(load_tile(lse_dram_window)) lse_block_tile;
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
constexpr bool is_main_body = is_prologue && is_epilogue;
|
||||
|
||||
// init VGrad & KGrad
|
||||
decltype(gemm_1.MakeCBlockTile()) dv_acc;
|
||||
decltype(gemm_3.MakeCBlockTile()) dk_acc;
|
||||
|
||||
decltype(load_tile(k_lds_read_window)) k_reg_tensor;
|
||||
decltype(load_tile(v_lds_read_window)) v_reg_tensor;
|
||||
decltype(load_tile_transpose(kt_lds_read_window)) kt_reg_tensor;
|
||||
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
// __builtin_amdgcn_s_waitcnt(0);
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 1, Q@K Gemm0
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(0, seqlen_kv_step, number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto p_spans = decltype(p)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
|
||||
else
|
||||
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
0, seqlen_kv_step, p, randval_dram_window);
|
||||
}
|
||||
const auto p_gemm = [&]() { // dropout / type conversion
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) {
|
||||
return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
|
||||
},
|
||||
p);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(p);
|
||||
}
|
||||
}();
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
|
||||
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
|
||||
|
||||
dv_acc = gemm_1(pt_reg_tensor, dot_reg_tensor);
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbias = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
ds);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
|
||||
}
|
||||
});
|
||||
move_tile_window(ds_lds_read_window, {-kN0, 0});
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
dk_epilogue(dk_dram_window, dk_acc);
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, dv_acc);
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
};
|
||||
|
||||
for(index_t i = 0; i < seqlen_kv_start; i += kN0)
|
||||
{
|
||||
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
|
||||
main_body(std::true_type{}, std::false_type{});
|
||||
// Hot loop
|
||||
if(num_total_loop > 1)
|
||||
{
|
||||
do
|
||||
{
|
||||
main_body(std::true_type{}, std::true_type{});
|
||||
i_total_loops += 1;
|
||||
seqlen_kv_step += kN0;
|
||||
} while(i_total_loops < num_total_loop - 1);
|
||||
}
|
||||
main_body(std::false_type{}, std::true_type{});
|
||||
seqlen_kv_step += kN0;
|
||||
|
||||
const auto k_length = k_dram_block_window_tmp.get_window_lengths();
|
||||
const auto seqlen_kv_length = k_length.at(number<0>{});
|
||||
for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
|
||||
{
|
||||
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
else
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
// static_assert(kIsDeterministic);
|
||||
dq_epilogue(dq_dram_window, dq_acc);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline;
|
||||
} // namespace ck_tile
|
||||
@@ -65,7 +65,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -73,7 +74,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
|
||||
false,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
SwizzleA>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
|
||||
@@ -105,16 +106,19 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
(Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -293,26 +297,29 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kWarps = kBlockSize / get_warp_size();
|
||||
|
||||
constexpr index_t K2 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2;
|
||||
constexpr index_t K0 = ColsPerBlock / K1 / K2;
|
||||
static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes,
|
||||
constexpr index_t K3 = GetAlignmentK<Problem>(); // 8
|
||||
constexpr index_t K2 = WarpAlignmentBytes / sizeof(T) / K3; // 8
|
||||
constexpr index_t K_remain = ColsPerBlock / K2 / K3;
|
||||
constexpr index_t K1 = min(kWarps, K_remain);
|
||||
constexpr index_t K0 = K_remain / K1;
|
||||
static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) &&
|
||||
K2 * K3 * sizeof(T) == WarpAlignmentBytes,
|
||||
"ColsPerBlock notdivisible");
|
||||
|
||||
constexpr index_t N2 = get_warp_size() / K1;
|
||||
constexpr index_t N1 = kWarps / K0;
|
||||
constexpr index_t N2 = get_warp_size() / K2; // 8
|
||||
constexpr index_t N1 = max(1, kWarps / K1);
|
||||
constexpr index_t N0 = RowsPerBlock / N1 / N2;
|
||||
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) &&
|
||||
(K1 * N2 == get_warp_size()),
|
||||
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) &&
|
||||
(K2 * N2 == get_warp_size()),
|
||||
"RowsPerBlock not divisible");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2, 1>, sequence<1, 2>>, // K0 N1, N2 K1
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
sequence<1, 2>, // N0 K2
|
||||
sequence<0, 2>>{});
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2, 1>, sequence<1, 2>>, // K1 N1, N2 K2
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 2, 2>, // N0 K0 K3
|
||||
sequence<0, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -961,13 +968,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentBias<Problem>();
|
||||
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
|
||||
kMPerBlock * kNPerBlock / kBlockSize);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M2 = kMPerBlock / M1 / M0;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -74,7 +74,8 @@ template <typename BlockTile_, // sequence<...
|
||||
typename Gemm3BlockWarps_,
|
||||
typename Gemm3WarpTile_,
|
||||
typename Gemm4BlockWarps_,
|
||||
typename Gemm4WarpTile_>
|
||||
typename Gemm4WarpTile_,
|
||||
index_t kMaxSeqLenQ_ = 0>
|
||||
struct TileFmhaBwdShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
@@ -111,6 +112,10 @@ struct TileFmhaBwdShape
|
||||
// K/K^T at once
|
||||
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
|
||||
// that need load V at once
|
||||
|
||||
static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
|
||||
static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
|
||||
"kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user