mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Add qr_ks_vs_whole_k_prefetch_trload pipeline
This commit is contained in:
@@ -54,6 +54,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -235,11 +236,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kNPerBlock * kKPerBlock;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
@@ -958,7 +966,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
WarpGemm>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm1Warps)
|
||||
return BlockGemmARegBSmemCRegV2PrefetchN<GemmProblem, BlockGemmPolicy>{};
|
||||
{
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
return BlockGemmARegBSmemCRegV2PrefetchN<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemTrLoadCRegV2PrefetchN<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,835 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr bool kUseN0Loop = true;
|
||||
static constexpr bool kIgnoreFastExp2 = true;
|
||||
static constexpr bool kIsNaiveHDimLoad = true;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kN0Sub = BlockFmhaShape::kN0Sub;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 64)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async";
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& /* unused */,
|
||||
const AttentionVariantParams& /* unused */,
|
||||
const BlockIndices& /* unused */,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
// xformers path does not require the pipeline to output random values for host
|
||||
// verification, since a separate kernel is used to generate random values
|
||||
ignore = randval_dram_block_window_tmp;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
// usually kN0 is 128, kN0Sub/kK1 is 32/16
|
||||
static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
static_assert(n0_loops >= NumPrefetchV, "Check failed!");
|
||||
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
|
||||
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
Policy::template IsPreloadWholeNextIterationK<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
|
||||
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
|
||||
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto q_dram_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
auto k_tiles = [&]() {
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
return statically_indexed_array<k_tile_type, n0_loops>{};
|
||||
else
|
||||
return statically_indexed_array<k_tile_type, 1>{};
|
||||
}();
|
||||
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
return __expf(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start},
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
// assuming no random values need be saved, this is true when the pipeline is called from
|
||||
// xformers, since we have a separate kernel to generated randomm values
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to pass a null_randval_dram and tile window to the BlockDropout operator to
|
||||
// make it works
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
store_tile(q_lds_write_window, q_dram_tile, partition_index);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64
|
||||
{
|
||||
if(seqlen_k_curr == seqlen_k_start) // at first iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
else // at intermediate and last iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
// prefetch k_tile for next iteration
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
else // only preload one unroll of K for next iteration, used when kM0=128
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x000000001);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto y) {
|
||||
x += type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
pcomp_tile(i_j_idx) *= scale_s;
|
||||
position_encoding.update(pcomp_tile(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(pcomp_tile, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p =
|
||||
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]),
|
||||
partition_index);
|
||||
};
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
// store lse
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
|
||||
BDataType>::TransposedDstrEncode{};
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
|
||||
make_static_tile_distribution(b_warp_dstr_encode));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
|
||||
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user