mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
@@ -8,21 +8,16 @@
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
|
||||
@@ -286,11 +286,226 @@ struct BlockDropout
|
||||
});
|
||||
}
|
||||
|
||||
ck_tile::philox ph;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
// Primary template for the dropout helper parameterized on three compile-time flags.
// It is only declared here; the definitions are the partial specializations for
// IsDropout_ == false and IsDropout_ == true that follow.
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd;
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = false;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
(void)randval_dram_block_window_tmp;
|
||||
(void)seqlen_qk_start;
|
||||
|
||||
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = true;
|
||||
// true: 32*32 warp gemm
|
||||
// false: 16*16 warp gemm
|
||||
static constexpr bool IsWG32 = IsWG32_;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
// Constructs the dropout helper for one (batch, head) pair.
//
// i_batch, i_head, nheads  combined as (i_batch * nheads + i_head) to form a per-head index
// seed                     forwarded as the first argument of ck_tile::philox
// offset                   base value added to the second argument of ck_tile::philox
// rp_undrop_               stored in rp_undrop
// p_undrop_in_uint8_t_     stored in p_undrop_in_uint8_t
//
// The second philox argument is
//   offset + (i_batch * nheads + i_head) * get_warp_size() + <per-thread term>.
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
|
||||
index_t i_head,
|
||||
index_t nheads,
|
||||
unsigned long long seed,
|
||||
unsigned long long offset,
|
||||
float rp_undrop_,
|
||||
uint8_t p_undrop_in_uint8_t_)
|
||||
: ph(seed,
|
||||
offset + (i_batch * nheads + i_head) * get_warp_size() +
|
||||
// Per-thread term: with IsWG32 it is the lane id itself. Otherwise the mask
// 47 (0b101111) clears bit 4 of the lane id and that bit is replaced by the
// lowest bit of the warp id (shifted left by 4).
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
|
||||
rp_undrop(rp_undrop_),
|
||||
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
|
||||
constexpr index_t kMPerStep = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return MWarp * WG::kM * 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return MWarp * WG::kM;
|
||||
}
|
||||
}();
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = WG::kN;
|
||||
constexpr index_t kN1 = 8;
|
||||
constexpr index_t kN0 = kNPerStep / kN1;
|
||||
|
||||
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
|
||||
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
|
||||
number<kN1>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
|
||||
randval_lds_block_desc_0,
|
||||
ck_tile::make_tuple(
|
||||
make_pass_through_transform(number<kMPerStep>{}),
|
||||
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
|
||||
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return randval_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
|
||||
|
||||
constexpr index_t MIterPerWarp = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
|
||||
// except headdim256.
|
||||
constexpr auto randval_block_inner_part_dstr_encoding = []() {
|
||||
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::CDataType, float>)
|
||||
{
|
||||
if constexpr(IsWG32)
|
||||
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
else
|
||||
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(IsWG32)
|
||||
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
else
|
||||
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
randval_block_inner_part_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename PComputeDataType,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
|
||||
const index_t start_m0_idx,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
@@ -305,30 +520,177 @@ struct BlockDropout
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval.kThreadElementSpaceSize == 16);
|
||||
// randval tile in LDS
|
||||
auto randval_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
|
||||
|
||||
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
|
||||
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
|
||||
auto randval_lds_window = make_tile_window(
|
||||
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
|
||||
|
||||
// register distribute
|
||||
auto randval_dist_generated =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
|
||||
auto randval_lds_read_window =
|
||||
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
|
||||
randval_lds_window.get_window_lengths(),
|
||||
randval_lds_window.get_window_origin(),
|
||||
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
|
||||
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
|
||||
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
constexpr auto randval_dist_generated_spans =
|
||||
decltype(randval_dist_generated)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
|
||||
});
|
||||
});
|
||||
// save to LDS
|
||||
store_tile(randval_lds_window, randval_dist_generated);
|
||||
block_sync_lds();
|
||||
// read from LDS to register
|
||||
auto randval = load_tile(randval_lds_read_window);
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
|
||||
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
|
||||
constexpr index_t kMPerStep = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return MWarp * WG::kM * 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return MWarp * WG::kM;
|
||||
}
|
||||
}();
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval = make_static_distributed_tensor<uint8_t>(
|
||||
MakeRandValTileDistribution<BlockGemm, false>());
|
||||
if constexpr(IsWG32)
|
||||
static_assert(randval.kThreadElementSpaceSize == 16);
|
||||
else
|
||||
static_assert(randval.kThreadElementSpaceSize == 4 ||
|
||||
randval.kThreadElementSpaceSize == 8);
|
||||
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
int block_row_start, block_col_start;
|
||||
if constexpr(IsWG32)
|
||||
{
|
||||
block_row_start = (start_m0_idx / WG::kM) + i_m0;
|
||||
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
|
||||
}
|
||||
else
|
||||
{
|
||||
block_row_start = start_m0_idx / 32 + i_m0;
|
||||
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
|
||||
}
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t* random_uint8_t_;
|
||||
if constexpr(MBwdWG16SingleIterCheck)
|
||||
{
|
||||
uint8_t random_uint8_t[4];
|
||||
// m0t0 ~m0t15/m0t32~m0t47: 0
|
||||
// m0t16~m0t31/m0t48~m0t63: 1
|
||||
// m1t0 ~m1t15/m1t32~m1t47: 2
|
||||
// m1t16~m1t31/m1t48~m1t63: 3
|
||||
const index_t start_idx =
|
||||
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
|
||||
ph.get_random_4x8(
|
||||
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
else if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
uint8_t random_uint8_t[8];
|
||||
// t0 ~t15/t32~t47: 0
|
||||
// t16~t31/t48~t63: 1
|
||||
const index_t start_idx = (get_lane_id() >> 4) & 1;
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t,
|
||||
reinterpret_cast<unsigned long long&>(rowcol));
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval(r_idx) = random_uint8_t[i_random_idx++];
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval(r_idx) = random_uint8_t_[i_random_idx++];
|
||||
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
|
||||
idx0.impl_.at(1),
|
||||
idx0.impl_.at(2)>{};
|
||||
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
@@ -337,19 +699,19 @@ struct BlockDropout
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {kMPerStep, 0});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
|
||||
}
|
||||
@@ -358,7 +720,6 @@ struct BlockDropout
|
||||
ck_tile::philox ph;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,54 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaBwdTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
template <ck_tile::index_t kBlockSize>
|
||||
struct FmhaBwdOGradDotOTilePartitioner
|
||||
{
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
|
||||
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
_SS_(FmhaPipeline::name) +
|
||||
(pn.empty() ? "" : "_" + pn) +
|
||||
(kStoreLSE ? "_lse" : "" ) +
|
||||
(kStoreLSE ? "_lse" : "" ) +
|
||||
(kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o};
|
||||
batch_stride_o,
|
||||
batch_stride_lse_acc};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
const long_index_t batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_lse_acc = query_start;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
batch_stride_v,
|
||||
batch_stride_lse_acc};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
const long_index_t batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_lse_acc = query_start;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
|
||||
141
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
Normal file
141
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
Normal file
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdConvertQGrad
|
||||
{
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
|
||||
static constexpr index_t kM0 = Problem::kM0;
|
||||
static constexpr index_t kN0 = Problem::kN0;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
|
||||
static constexpr index_t kAlignmentQGradAcc =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
// Convert only
|
||||
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
|
||||
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<AccDataType,
|
||||
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
|
||||
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradDramTileDistribution<Problem>());
|
||||
|
||||
auto dq_acc = load_tile(dq_acc_dram_window);
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
}
|
||||
|
||||
// Reduce + Convert
|
||||
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
|
||||
QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
index_t nsplits) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<AccDataType,
|
||||
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
|
||||
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
|
||||
|
||||
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
|
||||
clear_tile(dq_acc);
|
||||
|
||||
constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
|
||||
index_t i_total_loops = 0;
|
||||
auto dq_acc_buf = load_tile(dq_acc_dram_window);
|
||||
move_tile_window(dq_acc_dram_window, {1, 0, 0});
|
||||
|
||||
do
|
||||
{
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
dq_acc_buf = load_tile(dq_acc_dram_window);
|
||||
move_tile_window(dq_acc_dram_window, {1, 0, 0});
|
||||
|
||||
i_total_loops += 1;
|
||||
} while(i_total_loops < (nsplits - 1));
|
||||
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// declare dq
|
||||
constexpr auto dq_converted_dstr =
|
||||
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
|
||||
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
|
||||
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
|
||||
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
|
||||
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
|
||||
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,11 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdOGradDotO
|
||||
{
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Alias of BlockFmhaBwdPipelineDefaultPolicy with every *LoadOnce_ flag set to false.
// These template parameters are not used here, so the values are only placeholders.
|
||||
using BlockFmhaBwdOGradDotODefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ false,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ false,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,782 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
|
||||
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with the tensor distribution. The tensor distribution should be able to override this.
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
// K, HBM ->LDS ->Reg
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
// Early termination
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc & dv_acc are both cleared; return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
VDataType* v_lds_ptr =
|
||||
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
|
||||
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(kt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
auto v_block_tile = load_tile(v_dram_window);
|
||||
|
||||
store_tile(k_lds_write_window, k_block_tile);
|
||||
shuffle_tile(shuffled_k_block_tile, k_block_tile);
|
||||
store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
auto kt_reg_tensor = load_tile(kt_lds_read_window);
|
||||
|
||||
store_tile(v_lds_write_window, v_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>()));
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK0>{}),
|
||||
q_lds_window.get_window_origin(),
|
||||
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
|
||||
// QT: Reg -> Reg-> LDS
|
||||
auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
QDataType* qt_lds_ptr =
|
||||
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(qt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dO: HBM ->Reg ->LDS
|
||||
auto do_dram_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>());
|
||||
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
|
||||
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
do_lds_window.get_window_origin(),
|
||||
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
|
||||
// dOT: Reg ->Reg ->LDS
|
||||
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
|
||||
auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(dot_read_lds,
|
||||
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dS: Reg -> Reg -> LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>()));
|
||||
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto ds_lds_read_window =
|
||||
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK4>{}),
|
||||
ds_lds_window.get_window_origin(),
|
||||
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
|
||||
// Bias: HBM ->Reg ->Reg ->LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>()));
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
// BiasGrad
|
||||
// Reg ->LDS ->Reg ->HBM
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto dbias_dram_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto dbias_lds_read_window =
|
||||
make_tile_window(bias_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Hot loop
|
||||
while(i_total_loops < num_total_loop)
|
||||
{
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
auto lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
shuffle_tile(shuffled_q_block_tile, q_block_tile);
|
||||
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
|
||||
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto q_reg_tensor = load_tile(q_lds_read_window);
|
||||
auto lse = load_tile(lse_lds_read_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto s_acc = SPBlockTileType{};
|
||||
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto p = SPBlockTileType{};
|
||||
constexpr auto p_spans = decltype(p)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
|
||||
}
|
||||
const auto p_gemm = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
p);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(p);
|
||||
}
|
||||
}();
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto do_block_tile = load_tile(do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
auto d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
shuffle_tile(shuffled_do_block_tile, do_block_tile);
|
||||
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
|
||||
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem,
|
||||
decltype(pt_reg_tensor),
|
||||
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto do_reg_tensor = load_tile(do_lds_read_window);
|
||||
auto d = load_tile(d_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbias = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
ds);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto ds_reg_tensor = load_tile(ds_lds_read_window);
|
||||
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
|
||||
}
|
||||
});
|
||||
move_tile_window(ds_lds_read_window, {0, -kN0});
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
|
||||
i_total_loops += 1;
|
||||
seqlen_q_step += kM0;
|
||||
}
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,848 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = true;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
|
||||
// ... together with the tensor distribution. The tensor distribution should be able to override this.
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_kts_vr";
|
||||
|
||||
// Forwards the policy's GetSmemSize value for this Problem to the caller.
// NOTE(review): presumably the total LDS size in bytes needed by operator() -- the sibling
// GetSmemSizeK/GetSmemSizeKT values are added to a char* there -- confirm against the policy.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t policy_smem_size = Policy::template GetSmemSize<Problem>();
    return policy_smem_size;
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType,
|
||||
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: dk_acc & dv_acc are both still cleared at this point, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto kt_dram_block_window = kt_dram_block_window_tmp;
|
||||
|
||||
auto kt_dram_window = make_tile_window(
|
||||
kt_dram_block_window.get_bottom_tensor_view(),
|
||||
kt_dram_block_window.get_window_lengths(),
|
||||
kt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
|
||||
// load
|
||||
|
||||
auto kt_block_tile = load_tile(kt_dram_window);
|
||||
|
||||
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
|
||||
|
||||
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy alias for the KSKTSVR backward (dQ/dK/dV) pipeline.
// In this pipeline V is located in registers, while K and K^T are located in LDS.
// The alias instantiates BlockFmhaBwdPipelineDefaultPolicy with the K, K^T and V
// "load once" flags set to true and the Q, Q^T, OGrad and OGrad^T flags set to false.
|
||||
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ true,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,821 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
{
|
||||
// Element types used by the pipeline. Each one is taken from the corresponding
// member of Problem with const/volatile/reference qualifiers removed.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
// Type that P^T and SGrad^T tiles are converted to (via cast_tile / type_convert in
// operator()) before being passed to the block GEMMs.
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
// Mask type; operator() takes an instance of it by value.
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
// Shape descriptor from which the tile-size constants of this struct are read.
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
// Launch-related constants forwarded unchanged from Problem.
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// Tile sizes forwarded unchanged from BlockFmhaShape.
// operator() statically asserts that kM0 equals the first window length of the
// Q / OGrad / LSE / D / QGrad / Bias windows and that kN0 equals the first window
// length of the K / V windows.
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
// NOTE(review): kK0..kK4 look like the per-GEMM reduction-chunk sizes for gemm_0..gemm_4
// (the LDS windows in operator() are kM0 x kK0, kVHeaddim x kK1, kM0 x kK2 and
// kQKHeaddim x kK3) -- confirm against BlockFmhaShape.
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
// Head-dimension extents forwarded from BlockFmhaShape (Q/K side and V side).
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
// Compile-time "load once" flags for this pipeline variant: only K and V are marked
// as loaded once; Q, Q^T, K^T, OGrad and OGrad^T are not.
// NOTE(review): where these flags are consumed is not visible in this part of the
// file -- confirm against the kernel / policy code that reads them.
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
// Mode, padding and feature switches forwarded unchanged from Problem.
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
// Bias mode; declared auto, so its type is whatever Problem::BiasEnum is.
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// Last-dimension vector length used to create the tensor view (and to decide the
// buffer_load vector length),
|
||||
// together with the tensor distribution. The tensor distribution should be able to
// override this.
// Each constant falls back to a fixed small value when the relevant padding flag is
// set, and otherwise takes the value reported by Policy.
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
// NOTE(review): QGrad is the only alignment that falls back to 2 instead of 1 when
// padded; the reason is not visible here -- confirm this is intentional.
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
// Bias alignment is keyed on seqlen-K padding and uses the policy's transposed-bias query.
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
// Short identifier string for this pipeline variant.
// NOTE(review): presumably used when composing kernel names -- confirm against callers.
static constexpr const char* name = "ks_vr";
|
||||
|
||||
// Returns the shared-memory (LDS) size this pipeline needs, exactly as reported by
// Policy::GetSmemSize for this Problem; no adjustment is made here.
// NOTE(review): presumably in bytes, since the sibling Policy::GetSmemSizeK value is
// used as a char* offset into smem_ptr inside operator() -- confirm.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: dk_acc and dv_acc are both still cleared here, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy alias for the backward dQ/dK/dV pipeline variant in which V is
// located in registers and K is located in LDS.
//
// The alias instantiates BlockFmhaBwdPipelineDefaultPolicy with seven boolean
// template arguments. Only KLoadOnce_ and VLoadOnce_ are set to true; the Q, QT,
// KT, OGrad and OGradT "load once" flags are all false.
//
// NOTE(review): the flag names in the inline annotations below are the only
// evidence here of the parameter order; verify them against the template
// declaration in block_fmha_bwd_pipeline_default_policy.hpp.
// NOTE(review): "KSVR" in the alias name presumably abbreviates
// "K in shared memory (LDS), V in registers" - confirm.
|
||||
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,692 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = true;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "qs_ks_vr_dos";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: dk_acc and dv_acc have both been cleared at this point, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
}
|
||||
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
block_sync_lds();
|
||||
store_tile(do_lds_window, do_block_tile); // store the prefetch
|
||||
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
get_slice_tile(dot_lds_window,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
get_slice_tile(do_lds_window,
|
||||
sequence<0, i_k2 * kK2>{},
|
||||
sequence<kM0, (i_k2 + 1) * kK2>{}),
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
block_sync_lds();
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
get_slice_tile(qt_lds_window,
|
||||
sequence<0, i_k3 * kK3>{},
|
||||
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy alias for the QSKSVROGradS backward pipeline: V is kept in registers, while
// Q, K and dO (OGrad) are kept in LDS.
// Each template argument below is the *LoadOnce_ flag named in its inline comment.
|
||||
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ true,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,9 +8,8 @@ namespace ck_tile {
|
||||
// Identifies an FMHA backward pipeline variant. This enum is used for codegen pattern matching.
// NOTE(review): the enumerator list below restarts at 0, so KSKTSVR and KRKTRVR_IGLP share the
// value 0. The enclosing hunk header (-8,9 +8,8) suggests this listing interleaves removed and
// added diff lines; presumably only KRKTRVR_IGLP / KRKTRVR remain after the change.
// TODO: confirm against the actual header before relying on the enumerator values.
|
||||
enum class BlockFmhaBwdPipelineEnum
|
||||
{
|
||||
KSKTSVR = 0,
|
||||
QSKSVROGradS,
|
||||
KSVR,
|
||||
KRKTRVR_IGLP = 0,
|
||||
KRKTRVR,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,7 +24,9 @@ template <typename QDataType_,
|
||||
typename BiasGradDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdPipelineProblem
|
||||
{
|
||||
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
|
||||
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename AccDataType_,
|
||||
typename QGradDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kQKHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using QGradDataType = remove_cvref_t<QGradDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
|
||||
"kBlockSize should be divisible by get_warp_size()");
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kQKHeaddim = kQKHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
struct TileFmhaBwdConvertQGradTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
Normal file
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBRegCRegV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// M->N Warp
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename CType_,
|
||||
typename BlockWarps_,
|
||||
typename WarpGemm_>
|
||||
struct BlockGemmARegBRegCRegV1CustomPolicy
|
||||
{
|
||||
using AType = remove_cvref_t<AType_>;
|
||||
using BType = remove_cvref_t<BType_>;
|
||||
using CType = remove_cvref_t<CType_>;
|
||||
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
|
||||
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||
|
||||
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmARegBRegCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmARegBRegCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
|
||||
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
|
||||
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
|
||||
// SwizzleA warp GEMM aliases for F16 x F16 -> F32.
// Both wrap WarpGemmAttributeMfmaImplF16F16F32M32N32K8 in
// WarpGemmAtrributeMfmaIterateK_SwizzleA; the K8 alias passes 1 and the K16
// alias passes 2 as the second template argument.
// NOTE(review): that argument is presumably the K-iteration count (1 x K8,
// 2 x K8 = K16) - confirm against the attribute's definition.
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
|
||||
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
|
||||
// SwizzleA warp GEMM aliases for Bf16 x Bf16 -> F32.
// Both wrap WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 in
// WarpGemmAtrributeMfmaIterateK_SwizzleA; the K8 alias passes 1 and the K16
// alias passes 2 as the second template argument.
// NOTE(review): that argument is presumably the K-iteration count (1 x K8,
// 2 x K8 = K16) - confirm against the attribute's definition.
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
|
||||
|
||||
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// half_t specializations with TransposeC = false and SwizzleA = true:
// map the 32x32x8 and 32x32x16 shapes to the F16 SwizzleA warp GEMM aliases.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// bf16_t specializations with TransposeC = false and SwizzleA = true:
// map the 32x32x8 and 32x32x16 shapes to the Bf16 SwizzleA warp GEMM aliases.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
@@ -58,8 +65,15 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
using WarpGemmMfmaDispatcher = typename impl::
|
||||
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user