mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Add init codegen logic for fmha fwd appendkv
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
@@ -22,6 +24,9 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
378
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
Normal file
378
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
Normal file
@@ -0,0 +1,378 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_>
|
||||
struct FmhaFwdAppendKVKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
__host__ static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
using bfs = typename FmhaPipeline::BlockFmhaShape;
|
||||
using gbr = typename bfs::Gemm0BlockWarps;
|
||||
using gwt = typename bfs::Gemm0WarpTile;
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
std::string n;
|
||||
if (kPadSeqLenQ) n += "s";
|
||||
if (kPadSeqLenK) n += "sk";
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn);
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct EmptyKargs
|
||||
{
|
||||
};
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct CommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
};
|
||||
|
||||
struct BatchModeKargs : CommonKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
};
|
||||
|
||||
struct GroupModeKargs : CommonKargs
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v}, // args for common karg
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v}, // args for common karg
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(FmhaPipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
|
||||
FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdAppendKVTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
static constexpr const char* name = "shb";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,124 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
|
||||
struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = typename Problem::QDataType;
|
||||
using KDataType = typename Problem::KDataType;
|
||||
using VDataType = typename Problem::VDataType;
|
||||
|
||||
using BlockFmhaShape = typename Problem::BlockFmhaShape;
|
||||
using VLayout = typename BlockFmhaShape::VLayout;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kK0BlockLength <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
(void)q_dram_block_window_tmp;
|
||||
(void)q_element_func;
|
||||
(void)k_dram_block_window_tmp;
|
||||
(void)k_element_func;
|
||||
(void)v_dram_block_window_tmp;
|
||||
(void)v_element_func;
|
||||
(void)smem_ptr;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
return 16 / sizeof(QDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
return 16 / sizeof(KDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// TODO: not correct!
|
||||
if constexpr(total_pixels > 4)
|
||||
return 4;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(VDataType);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -32,6 +32,20 @@ struct TileFmhaTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdAppendKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
|
||||
Reference in New Issue
Block a user