mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
ck tile pagedkv prefill (#2405)
* add prefetching physical block id for pagedkv * start add pagedkv prefill * rename pipeline * add kernel for pagedkv * add an init version pagedkv prefill * fix redefine issue * add struct BlockFmhaFwdPagedKVPipelineProblem and fmha_fwd_pagedkv_args * generate dispatch code * add body generating code * comipling pass * remove dropout from pagedkv * set lse to false in generating code * start changing qr kernel to pagedkv * init version of kernerl with pagedkv * change names of file that are generated * chang host validation for pagedkv prefill * using iglp to change blockgemm * add kernel files to op head file * show parameters * rewrite print parameter fun * add fwd * remove default parameter of GridSize * format * fix nhead issue and add seqlen_k_ptr to batch mode * format code * remove no-longer used code * format * fix some comments --------- Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
@@ -34,6 +35,8 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
@@ -51,6 +51,27 @@ struct TrivialPageBlockNavigator
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t /*id*/) const
|
||||
{
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t /*block_index*/,
|
||||
TileWindow /*tile_window*/,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
@@ -153,6 +174,56 @@ struct PageBlockNavigator
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t id) const
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
if(id >= 0)
|
||||
tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
|
||||
fixed_offset);
|
||||
else
|
||||
tile_window.set_bottom_tensor_view_data_ptr(nullptr);
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
auto local_tile_window = tile_window; // not affect origin window
|
||||
ck_tile::move_tile_window(local_tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, local_tile_window.get_window_origin());
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
|
||||
if(new_block_index < num_blocks)
|
||||
{
|
||||
return physical_block_indices[new_block_index];
|
||||
}
|
||||
else
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
|
||||
{
|
||||
return block_index == num_blocks - 1;
|
||||
|
||||
1374
include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
Normal file
1374
include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,751 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: This class is a variant of the existing BlockFmhaFwdSplitKVPipelineQRKSVS pipeline.
|
||||
// Refactoring to extract shared logic is recommended as future work.
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_pagedkv";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
auto q = load_tile(q_dram_window);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
// k_dram_block_window
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
// divisible by kN0)
|
||||
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start =
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
// v_dram_window
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc); // initialize C
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
auto physical_next_block_id_k =
|
||||
__builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
// position_encoding accept only logical coordinates, do conversion here
|
||||
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
[&variant, &variant_params, &block_indices](auto& x) {
|
||||
x = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, x),
|
||||
block_indices.batch_idx,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout(apply_logits_transform, s_acc);
|
||||
#else
|
||||
tile_elementwise_inout(apply_logits_transform, s_acc);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col - kv_l2p_offset,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&,
|
||||
&i_page_block_v_ = i_page_block_v,
|
||||
&v_dram_window_ = v_dram_window](auto i_k1) {
|
||||
auto physical_next_block_id_v_ =
|
||||
__builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}));
|
||||
const auto v = load_tile(v_dram_window_); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
i_page_block_v_ = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{},
|
||||
v_dram_block_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
{
|
||||
if constexpr(128 >= Problem::BlockFmhaShape::kK0)
|
||||
return BlockGemmARegBSmemCRegV2R1<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -320,6 +320,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
auto physical_next_block_id_k =
|
||||
__builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -600,8 +605,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
@@ -612,6 +617,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static_for<0, k1_loops - 1, 1>{}([&,
|
||||
&i_page_block_v_ = i_page_block_v,
|
||||
&v_dram_window_ = v_dram_window](auto i_k1) {
|
||||
auto physical_next_block_id_v_ =
|
||||
__builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}));
|
||||
const auto v = load_tile(v_dram_window_); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
@@ -634,12 +642,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
i_page_block_v_ = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1});
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -61,6 +61,58 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdPagedKVPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
@@ -37,6 +37,34 @@ struct TileFmhaTraits
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kHasLogitsSoftCap_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kIsPagedKV_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaFwdPagedKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV2R1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(decltype(b_warp_window_tmp){})),
|
||||
KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensors;
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user