[CK_TILE] Add FAv3 fwd pipeline (#2731)

* Add FAv3 fwd pipeline

* Unpack v_pk_mul to hide v_mov

* Avoid compiler moving l compute across phase

* Sync sched_group_barrier() setting for masking cases

[ROCm/composable_kernel commit: d876e87fe4]
This commit is contained in:
Po Yen Chen
2025-09-01 09:16:45 +08:00
committed by GitHub
parent 06fad18aaf
commit 884ecea071
16 changed files with 3270 additions and 0 deletions

View File

@@ -144,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
fmha_fwd_v3.cpp
${FMHA_FWD_V3_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -0,0 +1,492 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "3328", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"0",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "0", "permute output")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("v", "1", "0:no verify, 1:verify")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "30", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
}
enum class TensorLayout
{
bhsd,
bshd,
};
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
switch(layout)
{
case TensorLayout::bhsd: return stream << "bhsd";
case TensorLayout::bshd: return stream << "bshd";
default: return stream << "unknown";
}
}
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
batch = args.get_int("b");
seqlen_q = args.get_int("s");
seqlen_k = args.get_int("s_k");
if(seqlen_k < 0)
{
seqlen_k = seqlen_q;
}
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
softmax_scale = args.get_float("scale_s");
if(softmax_scale == .0f)
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k);
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
}
std::vector<ck_tile::index_t> get_query_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
std::vector<ck_tile::index_t> get_key_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_value_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_output_shape() const
{
if(output_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
float softmax_scale;
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
};
struct RunConfig
{
explicit RunConfig(const ck_tile::ArgParser& args)
{
seed = args.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
kernel_warmup = args.get_int("warmup");
kernel_repeat = args.get_int("repeat");
verify = args.get_bool("v");
}
std::optional<uint32_t> seed;
int kernel_warmup;
int kernel_repeat;
bool verify;
};
template <typename DataType>
auto generate_qkv(const Problem& problem,
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
-> std::tuple<ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>>
{
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
return std::make_tuple(q, k, v);
}
namespace host {
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
const int nr = nhead_q / nhead_kv;
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
ck_tile::fmha_fwd_v3_args args;
args.data_type = problem.data_type;
args.batch = problem.batch;
args.seqlen_q = problem.seqlen_q;
args.seqlen_k = problem.seqlen_k;
args.nhead_q = problem.nhead_q;
args.nhead_kv = problem.nhead_kv;
args.hdim_qk = problem.hdim;
args.hdim_v = problem.hdim;
args.softmax_scale = problem.softmax_scale;
args.window_size_left = problem.mask.left;
args.window_size_right = problem.mask.right;
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.q_ptr = q_buf.GetDeviceBuffer();
args.stride_q =
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_q =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_k =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_v =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o =
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
? problem.hdim
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
run_config.kernel_warmup,
run_config.kernel_repeat};
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
if(!run_config.verify)
{
return true;
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])
{
auto [parse_result, args] = parse_cmd_args(argc, argv);
if(!parse_result)
{
std::cerr << "failed to parse command line arguments" << std::endl;
}
Problem problem(args);
RunConfig run_config(args);
const auto run = [&] {
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
{
return run_impl<ck_tile::fp16_t>(problem, run_config);
}
else
{
return run_impl<ck_tile::bf16_t>(problem, run_config);
}
};
return !run();
}

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
{
switch(data_type)
{
case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16";
case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16";
default: return stream << "unknown";
}
}
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
{
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
struct fmha_fwd_v3_args
{
enum class data_type_enum
{
fp16,
bf16
};
data_type_enum data_type;
// bool is_varlen;
index_t batch;
index_t seqlen_q;
index_t seqlen_k;
index_t nhead_q;
index_t nhead_kv;
index_t hdim_qk;
index_t hdim_v;
float softmax_scale;
index_t window_size_left;
index_t window_size_right;
index_t mask_type;
const void* q_ptr;
index_t stride_q;
index_t nhead_stride_q;
index_t batch_stride_q;
const void* k_ptr;
index_t stride_k;
index_t nhead_stride_k;
index_t batch_stride_k;
const void* v_ptr;
index_t stride_v;
index_t nhead_stride_v;
index_t batch_stride_v;
void* o_ptr;
index_t stride_o;
index_t nhead_stride_o;
index_t batch_stride_o;
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,159 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "fmha_fwd_v3.hpp"
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
using o_dtype = ck_tile::bf16_t;
using lse_dtype = float;
};
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
static constexpr bool is_masking = IsMasking;
// M0 N0 K0 N1 K1
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
using fmha_block_warps = sequence<8, 1, 1>;
using fmha_shape = TileFmhaShape<fmha_block_tile,
fmha_block_warps,
fmha_warp_gemm_shape,
fmha_block_warps,
fmha_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
true, // kPadSeqLenK
false, // kPadHeadDimQ
false, // kPadHeadDimV
false, // kStoreLSE
-1 // kBlockPerCu
>;
using fmha_mask = SimplifiedGenericAttentionMask<IsMasking>;
using fmha_pipeline_problem =
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
fmha_shape,
IsVariableSeqlen,
fmha_mask,
fmha_traits>;
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
true, // kPadM
true, // kPadM
true // UseRawStore
>>;
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
};
template <typename Kernel>
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
{
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_qk,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_kv,
args.softmax_scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
0, // batch_stride_lse
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,31 @@
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
VALID=0
for causal in 0 1 ; do
for prec in "fp16" "bf16" ; do
for hdim in 128 ; do
for perm in 0 ; do
if [ $causal -eq 0 ]; then
mask=0
else
mask=b:-1,0
fi
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
done
done
done
done

View File

@@ -18,6 +18,7 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
@@ -40,6 +41,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"

View File

@@ -0,0 +1,519 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <type_traits>
#include <utility>
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct FmhaFwdCommonKargs
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
};
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o};
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
using namespace ck_tile;
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
using namespace ck_tile;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
batch_offset_o = query_start * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
{
return;
}
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}();
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
{0, i_n1});
// lse
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<1>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
smem_ptr);
}();
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,603 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
namespace ck_tile {
struct BlockFmhaV3PipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 4;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
// TODO: GetAlignment*() currently didn't consider if need padding or not
// so in pipeline still need check padding requirement
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
{
using namespace ck_tile;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
{
using namespace ck_tile;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
// compute the endcoding before transpose
constexpr auto v_block_dstr =
make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(v_block_dstr_encode),
typename Problem::VDataType>::TransposedDstrEncode{});
return v_block_dstr;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here
return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{};
}
}();
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm),
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
/// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
/// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
using WarpGemm = WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return v_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
{
using namespace ck_tile;
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
using namespace ck_tile;
static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t k_element_space_size =
MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t v_element_space_size =
MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
GetSingleSmemElementSpaceSize<Problem>());
/// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() &
/// MakeVLdsBlockDescriptor()
static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
constexpr index_t kv_element_space_size_in_bytes =
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
return kv_element_space_size_in_bytes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 4 * GetSmemSizeKV<Problem>();
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
@@ -262,4 +263,47 @@ struct BlockFmhaFwdAppendKVPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename LSEDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaFwdV3PipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -148,4 +148,20 @@ struct TileFmhaBwdConvertQGradTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdV3Traits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile