[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)

* Make fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to reserve the dtype orders

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment warning v3 kernel users

* Fix wrong codegen logics

* Remove unnecessary param

* Fix format

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 05292b3604]
This commit is contained in:
Po Yen Chen
2025-12-05 10:31:12 +08:00
committed by GitHub
parent 96ff482d8d
commit d96f632fa1
22 changed files with 890 additions and 1449 deletions

View File

@@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
# EXCLUDE_FROM_ALL: only built when this target is requested explicitly.
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Pre-generated kernel instance TUs; CONFIGURE_DEPENDS re-checks the glob at
# build time so newly added instance files are picked up without a manual
# re-configure (at the cost of a stat per build).
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
    "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
    fmha_fwd_v3.cpp
    ${FMHA_FWD_V3_INSTANCES}
)
# Baseline compile options for this example only (kept target-scoped below).
# NOTE(review): --save-temps keeps compiler intermediates for every TU and can
# bloat the build tree — confirm it is intentional outside of debugging.
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
    -fgpu-flush-denormals-to-zero
    -Wno-undefined-func-template
    --save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
# Disable packed fp32 math when the compiler supports the flag, and expose the
# choice to the source via a matching preprocessor definition.
# NOTE(review): assumes include(CheckCXXCompilerFlag) happened earlier in this
# file or a parent scope — confirm.
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
    list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
        -mllvm --amdgpu-disable-packed-fp32=1
    )
    list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
        -DCK_TILE_DISABLE_PACKED_FP32=1
    )
endif()
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -30,16 +30,24 @@ _MASK_MAP = {
}
def get_mask_map(mask: str):
if mask == "generic":
def get_mask_map(mask_impl: str):
if mask_impl == "generic":
return _MASK_MAP
elif mask == "simplified":
elif mask_impl == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_impl(mask: str) -> str:
    """Return the mask implementation family for a mask name.

    Mask names prefixed with "s_" belong to the "simplified" implementation;
    every other name belongs to the "generic" one.
    """
    if mask.startswith("s_"):
        return "simplified"
    return "generic"
def get_mask_cpp_type(mask: str) -> str:
    # Resolve the C++ type string for `mask`: pick the impl-specific map
    # ("simplified" for "s_"-prefixed names, else "generic"), then index by name.
    # Raises KeyError if `mask` is not in the selected map.
    return get_mask_map(get_mask_impl(mask))[mask]
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
@@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
return None
def get_mask_cpp_check_expr(mask: str) -> str:
    # C++ boolean expression (over runtime trait object `t`) that checks whether
    # the given mask variant applies; looked up in the impl-specific check map.
    return get_mask_check_map(get_mask_impl(mask))[mask]
QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
@@ -122,6 +134,7 @@ PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
}
PIPELINE_ENUM_MAP = {
@@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = {
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
}
BOOL_MAP = {

File diff suppressed because it is too large Load Diff

View File

@@ -1,616 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
// Parse command-line options for the fmha_fwd_v3 example.
// Returns {parse_ok, parser}: parse_ok is false when argv could not be parsed;
// the parser carries defaults for every option either way.
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
    ck_tile::ArgParser arg_parser;
    // Each insert() registers one option: name, default value, help text.
    arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s", "3328", "seqlen_q")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q & k")
        .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
        .insert("iperm",
                "0",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "0", "permute output")
        .insert("causal", "0", "0: no mask, 1: causal mask")
        .insert("v", "1", "0:no verify, 1:verify")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "30", "number of iterations to benchmark the kernel")
        // Optional effective seqlen override (exclude PAD) for batch mode
        .insert("q_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.")
        .insert("kv_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.");
    bool result = arg_parser.parse(argc, argv);
    return std::make_pair(result, arg_parser);
}
// Dimension order of a 4-D attention tensor:
//   bhsd = (batch, nhead, seqlen, hdim), bshd = (batch, seqlen, nhead, hdim).
enum class TensorLayout
{
    bhsd,
    bshd,
};

// Stream the layout tag as its lowercase dimension-order string.
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
    const char* name = "unknown";
    if(layout == TensorLayout::bhsd)
    {
        name = "bhsd";
    }
    else if(layout == TensorLayout::bshd)
    {
        name = "bshd";
    }
    return stream << name;
}
// Problem description derived from the parsed command line: tensor shapes,
// layouts, masking, and the softmax scale.
struct Problem
{
    explicit Problem(const ck_tile::ArgParser& args)
    {
        // "prec" selects between the two supported element types; anything
        // other than "fp16" falls through to bf16.
        data_type = args.get_str("prec") == "fp16"
                        ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
                        : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
        batch    = args.get_int("b");
        seqlen_q = args.get_int("s");
        seqlen_k = args.get_int("s_k");
        if(seqlen_k < 0)
        {
            seqlen_k = seqlen_q; // -1 means "same as seqlen_q"
        }
        nhead_q  = args.get_int("h");
        nhead_kv = args.get_int("h_k");
        if(nhead_kv < 0)
        {
            nhead_kv = nhead_q; // -1 means "same as nhead_q" (no GQA/MQA)
        }
        hdim = args.get_int("d");
        softmax_scale = args.get_float("scale_s");
        if(softmax_scale == .0f)
            softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim)); // default 1/sqrt(d)
        // "causal" selects a "b:-1,0" mask spec; otherwise "0" (no mask).
        // NOTE(review): exact mask_info::decode() semantics live elsewhere — confirm.
        const auto is_causal = args.get_bool("causal");
        if(is_causal)
        {
            mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
        }
        else
        {
            mask = mask_info::decode("0", seqlen_q, seqlen_k);
        }
        // iperm/operm == 1 selects bhsd; otherwise bshd.
        input_layout  = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
        output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
        // Optional per-batch effective lengths (empty = no override).
        q_eff_lens  = args.get_int_vec("q_eff_lens");
        kv_eff_lens = args.get_int_vec("kv_eff_lens");
    }

    // Shape helpers below return the 4-D lengths in the order dictated by the
    // chosen layout (bhsd vs bshd).
    std::vector<ck_tile::index_t> get_query_shape() const
    {
        if(input_layout == TensorLayout::bhsd)
        {
            return {batch, nhead_q, seqlen_q, hdim};
        }
        else
        {
            return {batch, seqlen_q, nhead_q, hdim};
        }
    }

    std::vector<ck_tile::index_t> get_key_shape() const
    {
        if(input_layout == TensorLayout::bhsd)
        {
            return {batch, nhead_kv, seqlen_k, hdim};
        }
        else
        {
            return {batch, seqlen_k, nhead_kv, hdim};
        }
    }

    std::vector<ck_tile::index_t> get_value_shape() const
    {
        if(input_layout == TensorLayout::bhsd)
        {
            return {batch, nhead_kv, seqlen_k, hdim};
        }
        else
        {
            return {batch, seqlen_k, nhead_kv, hdim};
        }
    }

    std::vector<ck_tile::index_t> get_output_shape() const
    {
        if(output_layout == TensorLayout::bhsd)
        {
            return {batch, nhead_q, seqlen_q, hdim};
        }
        else
        {
            return {batch, seqlen_q, nhead_q, hdim};
        }
    }

    ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
    ck_tile::index_t batch;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t nhead_q;  // query heads
    ck_tile::index_t nhead_kv; // key/value heads (<= nhead_q for GQA/MQA)
    ck_tile::index_t hdim;     // shared head dim for q/k/v
    float softmax_scale;
    mask_info mask;
    TensorLayout input_layout;
    TensorLayout output_layout;
    std::vector<int> q_eff_lens;  // per-batch effective seqlen_q override
    std::vector<int> kv_eff_lens; // per-batch effective seqlen_k override
};
// Benchmark/verification knobs parsed from the command line.
struct RunConfig
{
    explicit RunConfig(const ck_tile::ArgParser& args)
    {
        kernel_warmup = args.get_int("warmup");
        kernel_repeat = args.get_int("repeat");
        verify        = args.get_bool("v");
        // A seed of 0 requests a non-deterministic seed, modelled as an
        // empty optional; any other value is used verbatim.
        const uint32_t seed_value = args.get_uint32("seed");
        if(seed_value != 0)
        {
            seed = seed_value;
        }
    }

    std::optional<uint32_t> seed; // empty => non-deterministic
    int kernel_warmup;            // untimed iterations before benchmarking
    int kernel_repeat;            // timed iterations
    bool verify;                  // compare against host reference when true
};
// Allocate Q/K/V host tensors shaped for `problem` and fill them with values
// drawn from a normal distribution. `seed` is forwarded to the filler; an
// empty optional presumably yields a non-deterministic fill — confirm against
// ck_tile::FillNormalDistribution.
template <typename DataType>
auto generate_qkv(const Problem& problem,
                  [[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
    -> std::tuple<ck_tile::HostTensor<DataType>,
                  ck_tile::HostTensor<DataType>,
                  ck_tile::HostTensor<DataType>>
{
    ck_tile::HostTensor<DataType> q(problem.get_query_shape());
    ck_tile::HostTensor<DataType> k(problem.get_key_shape());
    ck_tile::HostTensor<DataType> v(problem.get_value_shape());
    // {0.f, 3.f, seed}: presumably (mean, stddev, seed) — TODO confirm.
    ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
    ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
    ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
    return std::make_tuple(q, k, v);
}
namespace host {

// Host-side reference implementation of the fmha forward pass.
// All tensors are bshd-ordered: (batch, seqlen, nhead, hdim); the result is
// written into `o_bshd`. GQA/MQA is handled by mapping query head h to
// KV head h / nr, where nr = nhead_q / nhead_kv.
template <typename AccDataType,
          typename PDataType,
          typename QDataType,
          typename KDataType,
          typename VDataType,
          typename ODataType,
          typename QElementOp,
          typename KElementOp,
          typename VElementOp,
          typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
                           const ck_tile::HostTensor<KDataType>& k_bshd,
                           const ck_tile::HostTensor<VDataType>& v_bshd,
                           const mask_info& mask,
                           ck_tile::HostTensor<ODataType>& o_bshd,
                           const QElementOp& q_element_op = {},
                           const KElementOp& k_element_op = {},
                           const VElementOp& v_element_op = {},
                           const SAccElementOp& s_acc_element_op = {})
{
    const int batch_size = q_bshd.mDesc.get_lengths()[0];
    const int seqlen_q   = q_bshd.mDesc.get_lengths()[1];
    const int seqlen_kv  = k_bshd.mDesc.get_lengths()[1];
    const int nhead_q    = q_bshd.mDesc.get_lengths()[2];
    const int nhead_kv   = k_bshd.mDesc.get_lengths()[2];
    const int hdim_qk    = q_bshd.mDesc.get_lengths()[3];
    const int hdim_v     = v_bshd.mDesc.get_lengths()[3];
    const int nr         = nhead_q / nhead_kv; // query heads per KV head

    // Per-batch scratch tensors. Note v is laid out (nhead, hdim_v, seqlen)
    // so the P*V gemm can consume it directly.
    ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
    ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
    ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
    ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
    ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
    ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});

    // do computation for each batch
    for(int b = 0; b < batch_size; ++b)
    {
        // copy per-batch data from input tensors
        // clang-format off
        q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
        k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
        v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
        // clang-format on

        // S = elementwise-op'd Q x K product.
        ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
            q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);

        // Apply the attention mask to S.
        if(mask.type == mask_enum::no_mask)
        {
            ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
        }
        else if(mask.type == mask_enum::window_generic)
        {
            ck_tile::reference_batched_masking(
                s_host_ref,
                ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
                    mask.left, mask.right, seqlen_q, seqlen_kv));
        }
        else
        {
            // if left window size is negative, means causal
            // else means generic (for current batch)
            if(mask.left < 0)
                ck_tile::reference_batched_masking(
                    s_host_ref,
                    ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
                        mask.left,
                        mask.right,
                        seqlen_q,
                        seqlen_kv,
                        mask.type == mask_enum::mask_top_left));
            else
                ck_tile::reference_batched_masking(
                    s_host_ref,
                    ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
                        mask.left,
                        mask.right,
                        seqlen_q,
                        seqlen_kv,
                        mask.type == mask_enum::mask_top_left));
        }

        // P = softmax(S); then O = P x V.
        ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
            s_host_ref, p_host_ref, ck_tile::identity{});
        ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
            p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);

        // copy resulting per-batch data to the output tensor
        o_host_ref.ForEach(
            [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
    }
}

} // namespace host
// Run the fmha fwd v3 kernel for `problem` with element type `DataType`,
// print a timing summary, and (when requested) verify against the host
// reference. Returns true on success.
//
// Fixes vs. the previous revision:
//  - the flop count was computed as a product of plain ints and overflowed
//    32-bit arithmetic for realistic sizes (it is now accumulated in
//    std::size_t from the first factor);
//  - "faild" typo in the launch-failure message.
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
    auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);

    ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
    /// FIXME: use correct size for output tensor. just use q size for now since hdim_qk = hdim_v
    ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());

    q_buf.ToDevice(q.data());
    k_buf.ToDevice(k.data());
    v_buf.ToDevice(v.data());
    // Ensure output buffer is zero-initialized so padded regions compare cleanly
    o_buf.SetZero();

    // Translate Problem into the kernel-facing argument struct.
    ck_tile::fmha_fwd_v3_args args{};
    args.data_type     = problem.data_type;
    args.batch         = problem.batch;
    args.seqlen_q      = problem.seqlen_q;
    args.seqlen_k      = problem.seqlen_k;
    args.nhead_q       = problem.nhead_q;
    args.nhead_kv      = problem.nhead_kv;
    args.hdim_qk       = problem.hdim;
    args.hdim_v        = problem.hdim;
    args.softmax_scale = problem.softmax_scale;

    args.window_size_left  = problem.mask.left;
    args.window_size_right = problem.mask.right;
    args.mask_type         = static_cast<ck_tile::index_t>(problem.mask.type);

    // bshd: (batch, seqlen_q, nhead_q, hdim)
    // bhsd: (batch, nhead_q, seqlen_q, hdim)
    args.q_ptr = q_buf.GetDeviceBuffer();
    args.stride_q =
        problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
    args.nhead_stride_q =
        problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
    args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;

    // bshd: (batch, seqlen_k, nhead_kv, hdim)
    // bhsd: (batch, nhead_kv, seqlen_k, hdim)
    args.k_ptr = k_buf.GetDeviceBuffer();
    args.stride_k =
        problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
    args.nhead_stride_k =
        problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
    args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;

    // bshd: (batch, seqlen_k, nhead_kv, hdim)
    // bhsd: (batch, nhead_kv, seqlen_k, hdim)
    args.v_ptr = v_buf.GetDeviceBuffer();
    args.stride_v =
        problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
    args.nhead_stride_v =
        problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
    args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;

    // bshd: (batch, seqlen_q, nhead_q, hdim)
    // bhsd: (batch, nhead_q, seqlen_q, hdim)
    args.o_ptr = o_buf.GetDeviceBuffer();
    args.stride_o =
        problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
    args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
                              ? problem.hdim
                              : problem.seqlen_q * problem.hdim;
    args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;

    // Optional cumulative seqlen overrides (exclude PAD)
    const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
    const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;

    // Expand a (possibly short) per-batch override list to length `batch`,
    // repeating the last entry; fall back to `fallback` when no override.
    auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
        std::vector<ck_tile::index_t> eff;
        if(!opt_vec.empty() && opt_vec[0] != -1)
        {
            eff.assign(opt_vec.begin(), opt_vec.end());
            if(eff.size() < static_cast<size_t>(problem.batch))
            {
                eff.resize(problem.batch, eff.back());
            }
        }
        else
        {
            eff.assign(problem.batch, fallback);
        }
        return eff;
    };
    const auto eff_q_vec  = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
    const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);

    // Calculate cumulative sums for kernel arguments if varlen is used
    std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
    auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
                                    std::vector<ck_tile::index_t>& cum_vec) {
        cum_vec.resize(per_batch_vec.size() + 1);
        cum_vec[0] = 0;
        for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
            cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
    };
    if(has_varlen_q)
    {
        calculate_cumulative(eff_q_vec, cuq_cum);
    }
    if(has_varlen_k)
    {
        calculate_cumulative(eff_kv_vec, cukv_cum);
    }
    ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
    ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
    cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
    cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
    args.cu_seqlen_q_ptr =
        !cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
                         : nullptr;
    args.cu_seqlen_kv_ptr =
        !cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
                          : nullptr;

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /*log_level=*/0,
                                         run_config.kernel_warmup,
                                         run_config.kernel_repeat};

    auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
    if(!result)
    {
        std::cerr << "failed to run fmha_fwd_v3()" << std::endl;
        return false;
    }

    // Accumulate in std::size_t from the first factor: the former plain-int
    // product 4*batch*nhead*seqlen_q*seqlen_k*hdim overflows 32-bit int even
    // for the default problem size (2*8*3328*3328*128 >> 2^31).
    const std::size_t flop = [&] {
        const std::size_t base = static_cast<std::size_t>(problem.batch) * problem.nhead_q *
                                 problem.seqlen_q * problem.seqlen_k * problem.hdim;
        if(problem.mask.type == mask_enum::no_mask)
        {
            return 4 * base;
        }
        else
        {
            /// FIXME: Use a more accurate method; for now, we're just dividing the flop by 2.
            return 2 * base;
        }
    }();
    // `time` is in ms, so flop / 1e9 / ms yields Tflop/s.
    float tflops = static_cast<float>(flop) / 1.e9 / time;

    std::cout << "[" << problem.data_type << "|";
    if(problem.input_layout == problem.output_layout)
    {
        std::cout << problem.input_layout;
    }
    else
    {
        std::cout << problem.input_layout << "-" << problem.output_layout;
    }
    std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
              << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
              << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
              << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
              << " TFlops" << std::endl;

    if(!run_config.verify)
    {
        return true;
    }

    // transpose tensor descriptors from bhsd to bshd if necessary
    if(problem.input_layout != TensorLayout::bshd)
    {
        q = q.transpose({0, 2, 1, 3});
        k = k.transpose({0, 2, 1, 3});
        v = v.transpose({0, 2, 1, 3});
    }
    ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
    if(problem.output_layout != TensorLayout::bshd)
    {
        o_ref = o_ref.transpose({0, 2, 1, 3});
    }

    // If variable lengths are provided, compute per-batch references
    // with the effective lengths; else compute a single full reference.
    if(has_varlen_q || has_varlen_k)
    {
        // Variable-length aware verification: zero-fill padded region and only compute valid part.
        o_ref.SetZero();
        for(int b = 0; b < problem.batch; ++b)
        {
            const ck_tile::index_t seqlen_q_eff  = eff_q_vec[b];
            const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
            if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
                continue;

            // Slice current batch from inputs (bshd) and build single-batch tensors
            ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
            ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
            ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
            ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});

            // Copy effective region
            q_b.ForEach([&](auto& self, auto idx) {
                // idx: [0, s, h, d]
                self(idx) = q(b, idx[1], idx[2], idx[3]);
            });
            k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
            v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });

            // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
            host::fmha_fwd<float, DataType>(q_b,
                                            k_b,
                                            v_b,
                                            problem.mask,
                                            o_b,
                                            ck_tile::identity{},
                                            ck_tile::identity{},
                                            ck_tile::identity{},
                                            ck_tile::scales{problem.softmax_scale});

            // Scatter into o_ref's bshd descriptor memory
            for(int s = 0; s < seqlen_q_eff; ++s)
            {
                for(int h = 0; h < problem.nhead_q; ++h)
                {
                    for(int d = 0; d < problem.hdim; ++d)
                    {
                        o_ref(b, s, h, d) = o_b(0, s, h, d);
                    }
                }
            }
        }
    }
    else
    {
        // No varlen override: compute the full reference once
        host::fmha_fwd<float, DataType>(q,
                                        k,
                                        v,
                                        problem.mask,
                                        o_ref,
                                        ck_tile::identity{},
                                        ck_tile::identity{},
                                        ck_tile::identity{},
                                        ck_tile::scales{problem.softmax_scale});
    }

    ck_tile::HostTensor<DataType> o(problem.get_output_shape());
    o_buf.FromDevice(o.data());

    // Looser tolerance for bf16 (fewer mantissa bits than fp16).
    const auto [rtol, atol] = [&] {
        if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
            return std::make_tuple(1e-3, 1e-3);
        else
            return std::make_tuple(1e-2, 1e-2);
    }();
    return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
// Entry point: parse arguments, build the problem description, and run the
// precision-specific benchmark. Exit code 0 on success, 1 on failure.
int main(int argc, char* argv[])
{
    auto [parse_result, args] = parse_cmd_args(argc, argv);
    if(!parse_result)
    {
        std::cerr << "failed to parse command line arguments" << std::endl;
        // Bug fix: previously execution fell through and ran with
        // default/partial arguments despite the parse failure.
        return 1;
    }
    Problem problem(args);
    RunConfig run_config(args);
    // Dispatch on the requested element type.
    const auto run = [&] {
        if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
        {
            return run_impl<ck_tile::fp16_t>(problem, run_config);
        }
        else
        {
            return run_impl<ck_tile::bf16_t>(problem, run_config);
        }
    };
    return !run();
}

View File

@@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
}
// Build kernel arguments and the launch grid for an fmha fwd v3 kernel.
// Mirrors fmha_fwd_create_kargs_and_grids() but adds the v3-specific
// `remap_opt` workgroup-remapping knob and omits LSE output (lse_ptr and its
// strides are passed as null/zero).
template <typename FmhaKernel>
auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
{
    /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
    /// maximizes the kernel's performance.
    int remap_opt = 2;
    // Masked runs with a head count not divisible by 8, or with long
    // sequences, fall back to less aggressive remapping (1), and to none (0)
    // for very long sequences.
    if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
       ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
    {
        if(65536 <= args.seqlen_q)
        {
            remap_opt = 0;
        }
        else
        {
            remap_opt = 1;
        }
    }
    auto kargs = [&] {
        if constexpr(FmhaKernel::kIsGroupMode)
        {
            // Group mode: sequences are addressed via seqstart/seqlen pointer
            // tables instead of uniform batch strides.
            return FmhaKernel::MakeKargs(args.q_ptr,
                                         args.k_ptr,
                                         args.v_ptr,
                                         nullptr, // lse_ptr
                                         args.o_ptr,
                                         args.seqstart_q_ptr,
                                         args.seqstart_k_ptr,
                                         args.seqlen_q_ptr,
                                         args.seqlen_k_ptr,
                                         args.hdim_q,
                                         args.hdim_v,
                                         args.nhead_q,
                                         args.nhead_q / args.nhead_k, // heads per KV head (GQA/MQA)
                                         args.scale_s,
                                         args.stride_q,
                                         args.stride_k,
                                         args.stride_v,
                                         args.stride_o,
                                         args.nhead_stride_q,
                                         args.nhead_stride_k,
                                         args.nhead_stride_v,
                                         0, // nhead_stride_lse
                                         args.nhead_stride_o,
                                         args.window_size_left,
                                         args.window_size_right,
                                         args.mask_type,
                                         remap_opt,
                                         args.cu_seqlen_q_ptr,
                                         args.cu_seqlen_k_ptr);
        }
        else
        {
            // Batch mode: uniform seqlens plus per-batch strides.
            return FmhaKernel::MakeKargs(args.q_ptr,
                                         args.k_ptr,
                                         args.v_ptr,
                                         nullptr, // lse_ptr
                                         args.o_ptr,
                                         args.seqlen_q,
                                         args.seqlen_k,
                                         args.hdim_q,
                                         args.hdim_v,
                                         args.nhead_q,
                                         args.nhead_q / args.nhead_k, // heads per KV head (GQA/MQA)
                                         args.scale_s,
                                         args.stride_q,
                                         args.stride_k,
                                         args.stride_v,
                                         args.stride_o,
                                         args.nhead_stride_q,
                                         args.nhead_stride_k,
                                         args.nhead_stride_v,
                                         0, // nhead_stride_lse
                                         args.nhead_stride_o,
                                         args.batch_stride_q,
                                         args.batch_stride_k,
                                         args.batch_stride_v,
                                         0, // batch_stride_lse
                                         args.batch_stride_o,
                                         args.window_size_left,
                                         args.window_size_right,
                                         args.mask_type,
                                         remap_opt,
                                         args.cu_seqlen_q_ptr,
                                         args.cu_seqlen_k_ptr);
        }
    }();
    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaKernel>
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
{

View File

@@ -1,60 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
// Stream a human-readable name for the requested precision.
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
{
    const char* name = "unknown";
    if(data_type == fmha_fwd_v3_args::data_type_enum::fp16)
    {
        name = "fp16";
    }
    else if(data_type == fmha_fwd_v3_args::data_type_enum::bf16)
    {
        name = "bf16";
    }
    return stream << name;
}
// Select and launch a v3 kernel instance matching the runtime arguments.
// Only the fp16/bf16 x {no-mask, masked} combinations are instantiated (the
// second traits parameter, variable seqlen, is always false here); any other
// data type returns {false, -1.f} meaning "kernel skipped".
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
{
    if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
    {
        if(args.mask_type == static_cast<int>(mask_enum::no_mask))
        {
            // fp16, fixed seqlen, no masking
            using kernel_traits =
                fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
            return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
        }
        else
        {
            // fp16, fixed seqlen, with masking
            using kernel_traits =
                fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
            return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
        }
    }
    else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
    {
        if(args.mask_type == static_cast<int>(mask_enum::no_mask))
        {
            // bf16, fixed seqlen, no masking
            using kernel_traits =
                fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
            return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
        }
        else
        {
            // bf16, fixed seqlen, with masking
            using kernel_traits =
                fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
            return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
        }
    }
    // Unsupported data type: report "not launched".
    return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -1,73 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
// Host-facing argument struct for the standalone fmha fwd v3 API.
struct fmha_fwd_v3_args
{
    // Element type used for Q/K/V and the output.
    enum class data_type_enum
    {
        fp16,
        bf16
    };

    data_type_enum data_type;
    // bool is_varlen;
    index_t batch;
    index_t seqlen_q;
    index_t seqlen_k;
    index_t nhead_q;  // query heads
    index_t nhead_kv; // key/value heads (<= nhead_q for GQA/MQA)
    index_t hdim_qk;  // head dim shared by Q and K
    index_t hdim_v;   // head dim of V (and the output)
    float softmax_scale;
    index_t window_size_left;
    index_t window_size_right;
    index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
                       // window_size_right == 0).

    // Device pointers plus per-row/per-head/per-batch strides for each tensor.
    // NOTE(review): stride units (elements vs bytes) are assumed from usage in
    // the example — confirm against the kernel's MakeKargs.
    const void* q_ptr;
    index_t stride_q;
    index_t nhead_stride_q;
    index_t batch_stride_q;

    const void* k_ptr;
    index_t stride_k;
    index_t nhead_stride_k;
    index_t batch_stride_k;

    const void* v_ptr;
    index_t stride_v;
    index_t nhead_stride_v;
    index_t batch_stride_v;

    void* o_ptr;
    index_t stride_o;
    index_t nhead_stride_o;
    index_t batch_stride_o;

    // Optional batch-mode cumulative seqlen overrides (exclude PAD)
    // If provided, they override per-batch effective lengths to skip tail padding.
    const ck_tile::index_t* cu_seqlen_q_ptr  = nullptr; // [batch+1]
    const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -1,179 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
// Maps a runtime data_type_enum to the concrete element types the pipeline
// problem needs (Q/K/V/P, accumulator, output, LSE).
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;

template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
    using qkvp_dtype = ck_tile::half_t;
    using acc_dtype = float;
    using o_dtype = ck_tile::half_t;
    using lse_dtype = float;
};

template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
{
    using qkvp_dtype = ck_tile::bf16_t;
    using acc_dtype = float;
    using o_dtype = ck_tile::bf16_t;
    using lse_dtype = float;
};

// Compile-time description of one v3 kernel instance: tile shape, traits,
// mask, pipeline problem, pipeline, epilogue, and the final kernel type.
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
{
    // NOTE(review): "date_type" looks like a typo for "data_type"; renaming
    // would touch every use below (and any external users) — confirm first.
    static constexpr auto date_type = DataType;
    static constexpr bool is_variable_seqlen = IsVariableSeqlen;
    static constexpr bool is_masking = IsMasking;

    // Block tile sizes. NOTE(review): the original comment labelled these
    // "M0 N0 K0 N1 K1" but the sequence carries six entries — confirm the
    // intended meaning of each position against TileFmhaShape.
    using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
    using fmha_warp_gemm_shape = sequence<32, 32, 16>;
    using fmha_block_warps = sequence<8, 1, 1>;
    using fmha_shape = TileFmhaShape<fmha_block_tile,
                                     fmha_block_warps,
                                     fmha_warp_gemm_shape,
                                     fmha_block_warps,
                                     fmha_warp_gemm_shape,
                                     true // IsVLayoutRowMajor
                                     >;
    using fmha_traits = TileFmhaFwdV3Traits<true,  // kPadSeqLenQ
                                            true,  // kPadSeqLenK
                                            false, // kPadHeadDimQ
                                            false, // kPadHeadDimV
                                            false, // kStoreLSE
                                            -1     // kBlockPerCu
                                            >;
    using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
    using fmha_pipeline_problem =
        BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
                                      typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
                                      fmha_shape,
                                      IsVariableSeqlen,
                                      fmha_mask,
                                      fmha_traits>;
    using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
    // NOTE(review): the second epilogue flag is commented "kPadM" twice;
    // presumably the second should read kPadN — confirm against
    // Default2DEpilogueProblem's parameter order.
    using epilogue = Default2DEpilogue<
        Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
                                 typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
                                 true, // kPadM
                                 true, // kPadM
                                 true  // UseRawStore
                                 >>;
    using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
};
// Build kernel arguments from the generic v3 args, compute the launch grid, and
// run the kernel once. Returns the elapsed time (ms) reported by launch_kernel().
template <typename Kernel>
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
{
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
/// maximizes the kernel's performance.
// remap_opt selects the workgroup remapping strategy. The default (2) is only
// lowered for masked runs whose head count is not a multiple of 8 or whose
// seqlen_q exceeds 16384 (and further to 0 beyond 65536).
int remap_opt = 2;
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
{
if(65536 <= args.seqlen_q)
{
remap_opt = 0;
}
else
{
remap_opt = 1;
}
}
// LSE output is not produced here (the v3 instances in this file configure
// kStoreLSE=false), hence the null pointer and zero lse strides below.
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_qk,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_kv, // q-heads per kv-head (GQA/MQA ratio)
args.softmax_scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
0, // batch_stride_lse
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
remap_opt,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
// 0: presumably dynamic LDS byte count — confirm against make_kernel's signature.
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// Dispatch entry point for one fmha fwd v3 kernel configuration. Only declared
// here; each kernel-instance translation unit provides an explicit
// specialization via INST_FMHA_FWD_V3_DISPATCH.
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Dispatch instance: bf16, IsVariableSeqlen=false, IsMasking=true.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Dispatch instance: bf16, IsVariableSeqlen=false, IsMasking=false.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Dispatch instance: fp16, IsVariableSeqlen=false, IsMasking=true.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Dispatch instance: fp16, IsVariableSeqlen=false, IsMasking=false.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>&
printf("}");
}
// 1D tensor-coordinate transform that derives the lower index by applying a
// user-supplied functor to the upper index: idx_low[0] = functor(idx_up[0]).
// The single upper dimension has length `low_length`.
template <typename Functor, typename LowLength>
struct functor_transform : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
Functor functor_; // user-provided index mapping
UpLengths up_lengths_; // single-element tuple holding the upper length
CK_TILE_HOST_DEVICE constexpr functor_transform() = default;
CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor,
const LowLength& low_length)
: functor_{functor}, up_lengths_{make_tuple(low_length)}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
// Apply the functor to compute the lower index from the upper index.
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = functor_(idx_up[number<0>{}]);
}
// Incremental update: the functor is arbitrary, so the lower index is
// recomputed from scratch and the diff derived from the previous value.
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& up_idx) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
const auto idx_low_old = idx_low;
calculate_lower_index(idx_low, up_idx);
idx_diff_low = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
// The transform itself performs no validity filtering; the caller is
// responsible for supplying a functor whose outputs stay in range.
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// Note: When using functor_transform, ensure that the transformed coordinates
// are always valid for vectorized load/store operations.
// Vector lengths/strides are forwarded unchanged; no restriction is derived
// from the functor here (see note above).
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
};
//*******************************************************************************************************
template <typename LowLength>
@@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
return offset<LowLength, OffsetLength>{low_length, offset_length};
}
// Factory for functor_transform, deducing Functor and LowLength from the
// arguments; mirrors the other make_*_transform helpers in this header.
template <typename Functor, typename LowLength>
CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor,
const LowLength& low_length)
{
using transform_t = functor_transform<Functor, LowLength>;
return transform_t{functor, low_length};
}
} // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"

View File

@@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths
}
};
template <typename TensorView_, typename WindowLengths_>
template <typename TensorView_,
typename WindowLengths_,
typename = std::enable_if_t<is_tensor_view_v<TensorView_>>>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
@@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename = std::enable_if_t<is_tile_distribution_v<StaticTileDistribution>>>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,

View File

@@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask
mdiv y_ratio_mdiv;
};
// Type trait: true iff the given mask type is a specialization of
// GenericAttentionMask. Primary template covers all other types.
template <typename>
struct is_generic_attention_mask : std::false_type
{
};
template <bool IsMasking, bool IsLocal>
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
{
};
// Variable-template shorthand following the standard *_v naming convention.
template <typename Mask>
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask

View File

@@ -73,54 +73,6 @@ struct FmhaFwdKernel
#endif
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
// clang-format off
// t2s: compile-time mapping from element type(s) to the short dtype tag embedded
// in the kernel name. T2 defaults to T1; the two-type forms cover mixed
// input/output precisions (e.g. fp8 in / bf16 out -> "fp8bf16").
template <typename T1, typename T2 = T1> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
// clang-format on
// Builds the canonical instance name encoding dtype, mode, tile/warp shapes and
// every feature flag of this kernel configuration.
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
// pad-flag suffix: "p" followed by one tag per padded dimension, empty when
// nothing is padded (rendered below as "_npad").
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
}
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs

View File

@@ -12,6 +12,8 @@
namespace ck_tile {
/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and
/// instruction scheduling optimizations.
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
{
@@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
};
struct FmhaFwdGroupModeKargs
@@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel
kargs.batch_stride_lse = batch_stride_lse;
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
@@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel
{}, // placeholder for lse
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasMask)
@@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
if constexpr(kHasMask)
if constexpr(kIsGroupMode)
{
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
return dim3(nhead,
batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1));
}
else
{
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
return dim3(nhead,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
batch_size);
}
}
@@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel
// FmhaPipeline::kN1);
// assume that num_tile_n1 is always 1
if constexpr(kHasMask)
if constexpr(kIsGroupMode)
{
const index_t i_nhead = blockIdx.x;
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_batch = blockIdx.y;
const index_t i_block = blockIdx.z;
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
}
}
else
{
@@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
}
}
}
@@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_v = key_start_padded * kargs.stride_v;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
if constexpr(kStoreLSE)
{
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_lse = query_start;
}
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
if(kargs.seqlen_q_ptr != nullptr)
{
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
}
else if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
@@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
}
else
@@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
}

View File

@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
@@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
}
} // namespace detail
/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and
/// instruction scheduling optimizations.
template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
struct BlockFmhaFwdV3Pipeline
{
@@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
static_assert(is_generic_attention_mask_v<FmhaMask>);
static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
"we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
@@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline
static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kStoreLSE && !kHasDropout &&
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
!kSkipMinSeqlenQ),
"enable unsupported features");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this

View File

@@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum
QRKSVS_ASYNC,
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
};
template <BlockFmhaPipelineEnum>

View File

@@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
// Problem descriptor consumed by BlockFmhaFwdV3Pipeline: bundles the element
// types, block shape, group-mode flag, mask type and the padding/occupancy
// traits into one compile-time package.
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename LSEDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaFwdV3PipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>; // gemm0 accumulator
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>; // softmax compute
using LSEDataType = remove_cvref_t<LSEDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>; // gemm1 accumulator
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
// Workgroup size derived from the shape's total warp count.
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
// Compile-time feature switches for the fmha fwd v3 kernels: per-dimension
// padding flags, LSE storage, and an optional occupancy override.
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdV3Traits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile

View File

@@ -90,7 +90,7 @@ submodule = submodule_t()
# formatting
format_procs = []
for x in all_files:
dos2unix = f"python -m dos2unix {str(x)} {str(x)}"
dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}"
clang_format = f"clang-format -style=file -i {str(x)}"
# One process to avoid race conditions.
cmd = f"{dos2unix} && {clang_format}"