From 4420c881023afce36dc5dcb221ec2d0fdf89b314 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 1 Sep 2025 01:41:19 +0000 Subject: [PATCH] Merge commit 'd876e87fe45a58ab4f83b945a021ea5effb9b31d' into develop --- example/ck_tile/01_fmha/CMakeLists.txt | 22 + .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 492 +++++++ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 60 + example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 67 + example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 159 +++ .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 + .../01_fmha/script/benchmark_fwd_v3.sh | 31 + include/ck_tile/ops/fmha.hpp | 3 + .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 519 +++++++ .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 1198 +++++++++++++++++ ...ck_fmha_fwd_v3_pipeline_default_policy.hpp | 603 +++++++++ .../pipeline/block_fmha_pipeline_problem.hpp | 44 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 + 16 files changed, 3270 insertions(+) create mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp create mode 100755 example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index bd03aee924..5f495c76d8 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -144,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) +# add fmha_fwd_v3 example +set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") + +add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) +target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE + fmha_fwd_v3.cpp + ${FMHA_FWD_V3_INSTANCES} +) + +set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) +list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -fgpu-flush-denormals-to-zero + -Wno-undefined-func-template + --save-temps +) +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) + # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp new file mode 100644 index 0000000000..d2428e5152 --- /dev/null +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -0,0 +1,492 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmha_fwd.hpp" +#include "fmha_fwd_v3.hpp" +#include "mask.hpp" + +auto parse_cmd_args(int argc, char* argv[]) -> std::pair +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("prec", "fp16", "data type. fp16/bf16") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q & k") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "0", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "0", "permute output") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("v", "1", "0:no verify, 1:verify") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "30", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +enum class TensorLayout +{ + bhsd, + bshd, +}; + +std::ostream& operator<<(std::ostream& stream, TensorLayout layout) +{ + switch(layout) + { + case TensorLayout::bhsd: return stream << "bhsd"; + case TensorLayout::bshd: return stream << "bshd"; + default: return stream << "unknown"; + } +} + +struct Problem +{ + explicit Problem(const ck_tile::ArgParser& args) + { + data_type = args.get_str("prec") == "fp16" + ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 + : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; + batch = args.get_int("b"); + seqlen_q = args.get_int("s"); + seqlen_k = args.get_int("s_k"); + if(seqlen_k < 0) + { + seqlen_k = seqlen_q; + } + nhead_q = args.get_int("h"); + nhead_kv = args.get_int("h_k"); + if(nhead_kv < 0) + { + nhead_kv = nhead_q; + } + hdim = args.get_int("d"); + softmax_scale = args.get_float("scale_s"); + if(softmax_scale == .0f) + softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); + mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + } + + std::vector get_query_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + std::vector get_key_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_value_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_output_shape() const + { + if(output_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + ck_tile::fmha_fwd_v3_args::data_type_enum data_type; + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_kv; + ck_tile::index_t hdim; + float softmax_scale; + mask_info mask; + TensorLayout input_layout; + TensorLayout output_layout; +}; + +struct RunConfig +{ + explicit RunConfig(const ck_tile::ArgParser& args) + { + seed = args.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + kernel_warmup = args.get_int("warmup"); + kernel_repeat = args.get_int("repeat"); + verify = args.get_bool("v"); + } + + std::optional seed; + int kernel_warmup; + int kernel_repeat; + bool verify; +}; + +template +auto generate_qkv(const Problem& problem, + [[maybe_unused]] std::optional seed = std::nullopt) + -> std::tuple, + ck_tile::HostTensor, + ck_tile::HostTensor> +{ + ck_tile::HostTensor q(problem.get_query_shape()); + ck_tile::HostTensor k(problem.get_key_shape()); + ck_tile::HostTensor v(problem.get_value_shape()); + + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); + + return std::make_tuple(q, k, v); +} + +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ + const int batch_size = q_bshd.mDesc.get_lengths()[0]; + const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; + const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; + const int nhead_q = q_bshd.mDesc.get_lengths()[2]; + const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; + const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; + const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + + const int nr = nhead_q / nhead_kv; + + ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); + ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); + ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); + ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); + ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + + // do computation for each batch + for(int b = 0; b < batch_size; ++b) + { + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, seqlen_q, seqlen_kv)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + } + + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); + + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); + } +} +} // namespace host + +template +bool run_impl(const Problem& problem, const RunConfig& run_config) +{ + auto [q, k, v] = generate_qkv(problem, run_config.seed); + + ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); + /// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v + ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q.data()); + k_buf.ToDevice(k.data()); + v_buf.ToDevice(v.data()); + + ck_tile::fmha_fwd_v3_args args; + + args.data_type = problem.data_type; + args.batch = problem.batch; + args.seqlen_q = problem.seqlen_q; + args.seqlen_k = problem.seqlen_k; + args.nhead_q = problem.nhead_q; + args.nhead_kv = problem.nhead_kv; + args.hdim_qk = problem.hdim; + args.hdim_v = problem.hdim; + args.softmax_scale = problem.softmax_scale; + + args.window_size_left = problem.mask.left; + args.window_size_right = problem.mask.right; + args.mask_type = static_cast(problem.mask.type); + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.q_ptr = q_buf.GetDeviceBuffer(); + args.stride_q = + problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_q = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; + args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.k_ptr = k_buf.GetDeviceBuffer(); + args.stride_k = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_k = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.v_ptr = v_buf.GetDeviceBuffer(); + args.stride_v = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_v = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.o_ptr = o_buf.GetDeviceBuffer(); + args.stride_o = + problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_o = problem.output_layout == TensorLayout::bshd + ? problem.hdim + : problem.seqlen_q * problem.hdim; + args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + + ck_tile::stream_config stream_config{nullptr, + true, + /*log_level=*/0, + run_config.kernel_warmup, + run_config.kernel_repeat}; + + auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); + if(!result) + { + std::cerr << "faild to run fmha_fwd_v3()" << std::endl; + return false; + } + + std::size_t flop = [&] { + if(problem.mask.type == mask_enum::no_mask) + { + return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + else + { + /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + }(); + float tflops = static_cast(flop) / 1.e9 / time; + + std::cout << "[" << problem.data_type << "|"; + if(problem.input_layout == problem.output_layout) + { + std::cout << problem.input_layout; + } + else + { + std::cout << problem.input_layout << "-" << problem.output_layout; + } + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim + << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed + << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + << " TFlops" << std::endl; + + if(!run_config.verify) + { + return true; + } + + // transpose tensor descriptors from bhsd to bshd if necessary + if(problem.input_layout != TensorLayout::bshd) + { + q = q.transpose({0, 2, 1, 3}); + k = k.transpose({0, 2, 1, 3}); + v = v.transpose({0, 2, 1, 3}); + } + + ck_tile::HostTensor o_ref(problem.get_output_shape()); + if(problem.output_layout != TensorLayout::bshd) + { + o_ref = o_ref.transpose({0, 2, 1, 3}); + } + + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); + + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); +} + +int main(int argc, char* argv[]) +{ + auto [parse_result, args] = parse_cmd_args(argc, argv); + if(!parse_result) + { + std::cerr << "failed to parse command line arguments" << std::endl; + } + + Problem problem(args); + RunConfig run_config(args); + + const auto run = [&] { + if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) + { + return run_impl(problem, run_config); + } + else + { + return run_impl(problem, run_config); + } + }; + + return !run(); +} diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp new file mode 100644 index 0000000000..30019167fb --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" +#include "mask.hpp" + +namespace ck_tile { + +std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) +{ + switch(data_type) + { + case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; + case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; + default: return stream << "unknown"; + } +} + +std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) +{ + if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + } + else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + } + + return std::make_pair(false, -1.f); +} + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp new file mode 100644 index 0000000000..5361d27f0f --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/stream_config.hpp" + +namespace ck_tile { + +struct fmha_fwd_v3_args +{ + enum class data_type_enum + { + fp16, + bf16 + }; + + data_type_enum data_type; + // bool is_varlen; + + index_t batch; + index_t seqlen_q; + index_t seqlen_k; + index_t nhead_q; + index_t nhead_kv; + index_t hdim_qk; + index_t hdim_v; + + float softmax_scale; + + index_t window_size_left; + index_t window_size_right; + index_t mask_type; + + const void* q_ptr; + index_t stride_q; + index_t nhead_stride_q; + index_t batch_stride_q; + + const void* k_ptr; + index_t stride_k; + index_t nhead_stride_k; + index_t batch_stride_k; + + const void* v_ptr; + index_t stride_v; + index_t nhead_stride_v; + index_t batch_stride_v; + + void* o_ptr; + index_t stride_o; + index_t nhead_stride_o; + index_t batch_stride_o; +}; + +std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp new file mode 100644 index 0000000000..d6e4ac4c60 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" + +#include "fmha_fwd_v3.hpp" + +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ + template <> \ + std::pair fmha_fwd_v3_kernel_dispatch( \ + const fmha_fwd_v3_args& args, const stream_config& config) \ + { \ + return std::make_pair(true, \ + fmha_fwd_v3_kernel_launch(args, config)); \ + } + +namespace ck_tile { + +template +struct fmha_fwd_v3_problem_traits; + +template <> +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::bf16_t; + using acc_dtype = float; + using o_dtype = ck_tile::bf16_t; + using lse_dtype = float; +}; + +template +struct fmha_fwd_v3_kernel_traits +{ + static constexpr auto date_type = DataType; + static constexpr bool is_variable_seqlen = IsVariableSeqlen; + static constexpr bool is_masking = IsMasking; + + // M0 N0 K0 N1 K1 + using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; + using fmha_warp_gemm_shape = sequence<32, 32, 16>; + using fmha_block_warps = sequence<8, 1, 1>; + + using fmha_shape = TileFmhaShape; + + using fmha_traits = TileFmhaFwdV3Traits; + + using fmha_mask = SimplifiedGenericAttentionMask; + + using fmha_pipeline_problem = + BlockFmhaFwdV3PipelineProblem::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::lse_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, + fmha_shape, + IsVariableSeqlen, + fmha_mask, + fmha_traits>; + + using fmha_pipeline = BlockFmhaFwdV3Pipeline; + + using epilogue = Default2DEpilogue< + Default2DEpilogueProblem::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, + true, // kPadM + true, // kPadM + true // UseRawStore + >>; + + using kernel = FmhaFwdV3Kernel; +}; + +template +float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) +{ + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_qk, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_kv, + args.softmax_scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; + + return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +template +std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, + const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp new file mode 100644 index 0000000000..2dbe0b2098 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp new file mode 100644 index 0000000000..6f5eca97a1 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp new file mode 100644 index 0000000000..1c4c798af6 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp new file mode 100644 index 0000000000..077cb7b73c --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh new file mode 100755 index 0000000000..9c500edf9d --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)" +VALID=0 + +for causal in 0 1 ; do +for prec in "fp16" "bf16" ; do +for hdim in 128 ; do +for perm in 0 ; do + +if [ $causal -eq 0 ]; then + mask=0 +else + mask=b:-1,0 +fi + +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID + +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID + +done +done +done +done diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 16fde15c7b..31de21a726 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -40,6 +41,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp new file mode 100644 index 0000000000..be14a36353 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -0,0 +1,519 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" + +#include +#include + +namespace ck_tile { + +template +struct FmhaFwdV3Kernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for lse + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for lse + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + using namespace ck_tile; + + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using namespace ck_tile; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {0, i_n1}); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + kargs.scale_s, + smem_ptr); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp new file mode 100644 index 0000000000..20d84116d4 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -0,0 +1,1198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +#define ENABLE_ASM_MARKER 1 +#if ENABLE_ASM_MARKER +#define ASM_MARKER(marker) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " #marker); \ + __builtin_amdgcn_sched_barrier(0); +#else +#define ASM_MARKER(marker) +#endif + +#define ADD_SBARRIER_FOR_PHASE0 1 +#if !defined(CK_TILE_DISABLE_PACKED_FP32) +#define CK_TILE_DISABLE_PACKED_FP32 0 +#endif + +#define WARP_ID 0 +#define LANE_ID 0 + +#define ENABLE_DEBUG_STMTS 1 +#if ENABLE_DEBUG_STMTS +#define DEBUG_STMTS \ + if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) +#else +#define DEBUG_STMTS if constexpr(false) +#endif + +namespace ck_tile { + +template +struct CoreLoopScheduler; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) {} + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) {} + } + else + { + if constexpr(Phase == 0) {} + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) {} + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) {} + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) {} + } + else + { + if constexpr(Phase == 0) {} + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) {} + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +namespace detail { +CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) +{ +#if CK_TILE_DISABLE_PACKED_FP32 + return a * b + c; +#else + float result; + asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" + : [result] "=v"(result) + : [a] "v"(a), [b] "s"(b), [c] "v"(c)); + return result; +#endif +} + +CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) +{ + fp16x2_t result; + asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b) +{ + bf16x2_t result; + asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) +{ + fp32x2_t result; + asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} +} // namespace detail + +template +struct BlockFmhaFwdV3Pipeline +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + + static_assert(std::is_same_v, + "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); + + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr ck_tile::index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr ck_tile::index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr ck_tile::index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr ck_tile::index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr ck_tile::index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + return 2; + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // create another LDS buffer for p + return ck_tile::max(kM0 * kN1 * sizeof(PDataType), + Policy::template GetSmemSize() + + kM0 * kN0 * sizeof(PDataType)); + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() + { + using namespace ck_tile; + constexpr auto lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + return lds_block_desc; + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() + { + using namespace ck_tile; + constexpr auto lds_block_desc = make_naive_tensor_descriptor( + make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); + + return lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc) + { + using namespace ck_tile; + + auto tensor_view = + make_tensor_view(reinterpret_cast(base), desc); + return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); + } + + // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 + template + CK_TILE_DEVICE static constexpr void s_waitcnt() + { + // vmcnt use bits {[15:14],[3:0]} + // expcnt use bits [6:4] + // lgkmcnt use bits [11:8] + __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | + ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() + { + s_waitcnt(); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() + { + s_waitcnt<63, Lgkmcnt>(); + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + auto s_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + + auto p_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + + auto o_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto o_lds_window = + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + + auto m_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc1D()); + [[maybe_unused]] auto m_lds_window = + make_tile_window(m_lds, make_tuple(number{}), {0}); + + const index_t warp_group_id = get_warp_id() / 4; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + auto q_dram_window = make_tile_window_linear( + q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto k_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + auto v_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + statically_indexed_array( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> + k_lds_window_load; + + statically_indexed_array( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> + v_lds_window_load; + + decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())) q_tile; + + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; + + union sp_compute_type + { + CK_TILE_DEVICE sp_compute_type() {} + + decltype(gemm_0.MakeCBlockTile()) sp_compute; + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())) p; + }; + statically_indexed_array sp; + + decltype(gemm_1.MakeCBlockTile()) o_acc; + constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd() + // instructions should we move to fmha_alu1() + static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + + decltype(block_tile_reduce( + sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; + decltype(m) l; + + // initialize k_lds_window and v_lds_window + static_for<0, 2, 1>{}([&](auto idx) { + k_lds_window_load(idx) = make_tile_window( + make_lds_tile_window( + static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); + }); + + static_for<0, 2, 1>{}([&](auto idx) { + v_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + static_cast(smem_ptr) + + (idx + 2) * Policy::template GetSmemSizeKV(), + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution()); + }); + + { + auto origin_q = load_tile(q_dram_window); + auto transformed_q = tile_elementwise_in(q_element_func, origin_q); + + q_tile = transformed_q; + } + + clear_tile(o_acc); + set_tile(m, bit_cast(0xff7fffff)); // a bit larger than -infinity + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + index_t kv_token_start = seqlen_k_start; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + k_dram_window.init_raw(); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + v_dram_window.init_raw(); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); + static_assert(kN0 == kK1); + + constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + static_assert(NumWarpGroups == 2); + + [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { + printf("[POYENC] %s (size=%d): %5.2f", + name, + decltype(dist_tensor.thread_buf_)::size(), + ck_tile::type_convert(dist_tensor.thread_buf_[0])); + static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { + printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); + }); + printf("\n"); + }; + + [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { + const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); + const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + if constexpr(true || num_rows < num_cols) + { + for(int row = 0; row < num_rows; ++row) + { + int offset = desc.calculate_offset(make_tuple(row, 0)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + row, + ck_tile::type_convert(data[offset])); + for(int col = 1; col < num_cols; ++col) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + else + { + for(int col = 0; col < num_cols; ++col) + { + int offset = desc.calculate_offset(make_tuple(0, col)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + col, + ck_tile::type_convert(data[offset])); + for(int row = 1; row < num_rows; ++row) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + }; + + [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) { + const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + int offset = desc.calculate_offset(make_tuple(0)); + printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); + for(int e = 1; e < num_elems; ++e) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(e)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + }; + + // K_mem_su_ld_insts = 1 for 32 x 128 + // V_mem_su_ld_insts = 1 for 128 x 32 + static constexpr int K_mem_su_ld_insts = 1; + static constexpr int V_mem_su_ld_insts = 1; + + auto K_mem_load = [&](auto k_lds_write_idx) { + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + + /// FIXME: use the future-predicting method to move the window + // move K tile windows + move_tile_window(k_dram_window, {kN0, 0}); + }; + + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + }; + + auto V_mem_load = [&](auto v_lds_write_idx) { + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + __builtin_amdgcn_sched_barrier(0); + + /// FIXME: use the future-predicting method to move the window + move_tile_window(v_dram_window, {kK1, 0}); + }; + + auto V_lds_load = [&](auto v_lds_read_idx) { + kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + }; + + decltype(m) m_old; + SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + /// TODO: remove the sp_delta and use sp_compute directly + statically_indexed_array{}).sp_compute), 2> sp_delta; + + auto fmha_alu0 = [&](auto sp_reg_idx) { + m_old = m; // m{j-1} + static_assert(m.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowmax value"); + auto m_latest = block_tile_reduce( + sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), + bit_cast(m_latest.thread_buf_[0]), + false, + false); + /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler + m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(m_latest, f_max, bool_constant{}); +#endif + m = m_latest; + + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + }); + }); + /// TODO: move some fmha_alu1() code here if necessary + }; + + auto fmha_alu1 = [&](auto sp_reg_idx) { + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); + }); + }); + + auto rowsum_p = block_tile_reduce( + sp(sp_reg_idx).sp_compute, + sequence<1>{}, + f_sum, + SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + static_assert(rowsum_p.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowsum value"); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), + bit_cast(rowsum_p.thread_buf_[0]), + false, + false); + rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#endif + // update partial o_acc [0, 2) + static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + + // l{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + + l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); + }); + + // update partial o_acc [2, fmha_alu_D_reg_cnt) + static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + + /// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite + /// the cast_tile() call into inline asm to force the conversion instructions to be + /// generated here. The fmha_alu1() call should be placed at the end of a phase. + static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); + float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + if constexpr(std::is_same_v) + { + auto casted = detail::cvt_pk_fp16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + else + { + auto casted = detail::cvt_pk_bf16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + }); + }; + + auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + } + }; + + auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + fmha_alu0(number<1>{} - sp_reg_idx); + } + }; + + auto fmha_alu_D_upd = [&] { + o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + + fp32x2_t pk_o_acc_scale; + pk_o_acc_scale.x = o_acc_scale; + pk_o_acc_scale.y = o_acc_scale; + + static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); +#if CK_TILE_DISABLE_PACKED_FP32 + static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); +#endif + + constexpr auto issued_D_reg_cnt = +#if CK_TILE_DISABLE_PACKED_FP32 + fmha_alu_D_reg_cnt + 2 +#else + fmha_alu_D_reg_cnt +#endif + ; + /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call + /// should be placed at the end of a phase. + // update partial o_acc after [issued_D_reg_cnt] + static_for{}([&](auto idx) { + fp32x2_t input; + input.x = o_acc.thread_buf_[idx]; + input.y = o_acc.thread_buf_[idx + 1]; + + auto output = detail::pk_mul_f32(input, pk_o_acc_scale); + + o_acc.thread_buf_[idx] = output.x; + o_acc.thread_buf_[idx + 1] = output.y; + }); + }; + + auto fmha_mask = [&](auto sp_reg_idx) { + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(sp(sp_reg_idx).sp_compute, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = kv_token_start + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + }; + + auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) { + if constexpr(load_type == 0) + { + V_mem_load(mem_wr_idx); + K_lds_load(lds_rd_idx); + } + else + { + K_mem_load(mem_wr_idx); + V_lds_load(lds_rd_idx); + } + }; + + auto core_loop = [&](auto cl_p) { + auto gemm0 = number<0>{}; + auto gemm1 = number<1>{}; + + auto memV = number<0>{}; + auto memK = number<1>{}; + + using Scheduler = CoreLoopScheduler; + + auto iteration = [&](auto pi) { + auto xdl_SP_p01_reg_idx = number<1>{} - pi; + auto xdl_SP_p23_reg_idx = pi; + + auto K_w0_lds_wr_idx = number<1>{} - pi; + auto V_w0_lds_wr_idx = pi; + auto K_w0_lds_rd_idx = pi; + auto V_w0_lds_rd_idx = pi; + + auto K_w4_lds_wr_idx = number<1>{} - pi; + auto V_w4_lds_wr_idx = number<1>{} - pi; + auto K_w4_lds_rd_idx = number<1>{} - pi; + auto V_w4_lds_rd_idx = pi; + + bool result = true; + + if constexpr(cl_p == 0) + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave0-3 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave0-3 (pi=1)"); + } + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + fmha_mask(xdl_SP_p01_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave0-3"); + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<2>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); + + Scheduler::schedule(cl_p, number<3>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + } + else + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave4-7 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave4-7 (pi=1)"); + } + cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave4-7"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + fmha_mask(xdl_SP_p01_reg_idx); + + Scheduler::schedule(cl_p, number<2>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<3>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + } + return result; + }; + return iteration(number<0>{}) && iteration(number<1>{}); + }; + + auto fmha_post_process = [&](auto d) { + auto ps_pi = number<1>{} - d; + auto V_lds_rd_idx = ps_pi; + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + + V_lds_load(V_lds_rd_idx); + fmha_alu1(ps_pi); + + s_waitcnt_lgkmcnt<0>(); + + auto xdl_SP_p23_reg_idx = ps_pi; + gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); + }; + + // pre-stage + { + ASM_MARKER("before pre-stage"); + // (1) load K0 to LDS & VGPR + K_mem_load(number<0>{}); // mem_K0 + + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + K_lds_load(number<0>{}); // lds_K0 + + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 + if(1 < num_total_loop) + { + K_mem_load(number<1>{}); // mem_K1 + } + V_mem_load(number<0>{}); // mem_V0 + + // (3) mfma (Q*K0) + softmax + gemm(number<0>{}, /*gemm_idx=*/number<0>{}); + + fmha_mask(number<0>{}); + /// TODO: find better way to map fmha_alu(0,96) call + fmha_alu0(number<0>{}); + fmha_alu_D_upd(); + + kv_token_start += kN0; + ++i_total_loops; + if(num_total_loop <= i_total_loops) + { + goto label_main_loops_exit; + } + + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } + + ASM_MARKER("end pre-stage"); + } + + if(1 < num_total_loop) + { + if(warp_group_id == 0) + { + V_mem_load(number<1>{}); // V1 + K_lds_load(number<1>{}); // K1 + + asm volatile("s_setprio 0"); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<0>{})) + ; + } + if(warp_group_id != 0) + { + asm volatile("s_setprio 1"); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<1>{})) + ; + } + } + label_main_loops_exit: + if(num_total_loop % 2) + { + fmha_post_process(number<1>{}); + } + if(!(num_total_loop % 2)) + { + fmha_post_process(number<0>{}); + } + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp new file mode 100644 index 0000000000..e440280d7e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" + +namespace ck_tile { + +struct BlockFmhaV3PipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 4; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentK() + { + using namespace ck_tile; + using KDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(KDataType); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentV() + { + using namespace ck_tile; + using VDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentV(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + // compute the endcoding before transpose + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto GetQKBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + + return BlockGemmARegBRegCRegV2{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetPVBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass + /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single + using WarpGemm = WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + return BlockGemmARegBRegCRegV2{}; + } + + static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + + template + CK_TILE_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + template + CK_TILE_DEVICE static constexpr auto + MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return v_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + using namespace ck_tile; + + static_assert(MakeKLdsLoadBlockDescriptor().get_element_space_size() == + MakeKLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t k_element_space_size = + MakeKLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(MakeVLdsLoadBlockDescriptor().get_element_space_size() == + MakeVLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t v_element_space_size = + MakeVLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <= + GetSingleSmemElementSpaceSize()); + + /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() & + /// MakeVLdsBlockDescriptor() + static_assert(std::is_same_v); + constexpr index_t kv_element_space_size_in_bytes = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return kv_element_space_size_in_bytes; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 4 * GetSmemSizeKV(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 86ac713b6f..7775848195 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -262,4 +263,47 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdV3PipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index fb4713ccc0..cd3893f5cf 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -148,4 +148,20 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaFwdV3Traits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile