From 26840b623a94f2477604a94009a796bb6d4717f7 Mon Sep 17 00:00:00 2001 From: Dan Yao Date: Wed, 5 Jun 2024 02:12:45 +0800 Subject: [PATCH] CK Tile FA Training kernels (#1286) * FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 Co-authored-by: carlushuang Co-authored-by: rocking Co-authored-by: Po Yen Chen Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 2cab8d39e3c638d890774e43cc4ba6bcb982c54a] --- example/ck_tile/01_fmha/CMakeLists.txt | 39 +- example/ck_tile/01_fmha/fmha_bwd.cpp | 932 +++++++++++ example/ck_tile/01_fmha/fmha_bwd.hpp | 359 +++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 135 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 114 +- example/ck_tile/01_fmha/generate.py | 702 +++++++- .../ck_tile/01_fmha/script/benchmark_bwd.sh | 21 + .../script/{benchmark.sh => benchmark_fwd.sh} | 0 .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 33 + .../{smoke_test.sh => smoke_test_fwd.sh} | 22 +- include/ck_tile/core.hpp | 3 + .../core/arch/amd_buffer_addressing.hpp | 74 +- .../core/arch/generic_memory_space_atomic.hpp | 175 ++ include/ck_tile/core/numeric/vector_type.hpp | 11 +- include/ck_tile/core/tensor/buffer_view.hpp | 13 +- include/ck_tile/core/tensor/store_tile.hpp | 2 +- include/ck_tile/core/tensor/tensor_view.hpp | 30 +- .../ck_tile/core/tensor/tile_distribution.hpp | 1 + include/ck_tile/core/tensor/tile_window.hpp | 60 + include/ck_tile/core/tensor/update_tile.hpp | 55 + include/ck_tile/core/utility/philox_rand.hpp | 89 ++ include/ck_tile/host.hpp | 1 + include/ck_tile/host/host_tensor.hpp | 36 +- .../reference/reference_batched_dropout.hpp | 33 + include/ck_tile/ops/fmha.hpp | 14 + .../ck_tile/ops/fmha/block/block_dropout.hpp | 329 ++++ .../ck_tile/ops/fmha/block/block_masking.hpp | 72 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 1421 +++++++++++++++++ .../fmha/kernel/fmha_bwd_tile_partitioner.hpp | 54 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 180 ++- .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 2 +- .../fmha/pipeline/block_fmha_bwd_dot_do_o.hpp | 95 ++ ...block_fmha_bwd_dot_do_o_default_policy.hpp | 20 + ...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp | 848 ++++++++++ ...k_dv_pipeline_ks_kts_vr_default_policy.hpp | 20 + ...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp | 821 ++++++++++ ...dq_dk_dv_pipeline_ks_vr_default_policy.hpp | 20 + ...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp | 692 ++++++++ ...v_pipeline_qs_ks_vr_dos_default_policy.hpp | 20 + ...block_fmha_bwd_pipeline_default_policy.hpp | 1343 ++++++++++++++++ .../pipeline/block_fmha_bwd_pipeline_enum.hpp | 16 + .../block_fmha_bwd_pipeline_problem.hpp | 91 ++ .../pipeline/block_fmha_pipeline_problem.hpp | 29 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 54 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 59 +- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 45 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 29 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 46 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 49 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 +- 
include/ck_tile/ops/gemm.hpp | 7 +- .../block_gemm_areg_bgmem_creg_problem.hpp | 25 - .../block/block_gemm_areg_bgmem_creg_v1.hpp | 4 +- ...gemm_areg_bgmem_creg_v1_default_policy.hpp | 2 +- .../block/block_gemm_areg_bsmem_creg_v1.hpp | 161 +- ..._gemm_areg_bsmem_creg_v1_custom_policy.hpp | 2 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 2 +- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 2 +- ..._gemm_areg_bsmem_creg_v2_custom_policy.hpp | 2 +- ...gemm_areg_bsmem_creg_v2_default_policy.hpp | 2 +- .../block/block_gemm_asmem_breg_creg_v1.hpp | 228 +++ ..._gemm_asmem_breg_creg_v1_custom_policy.hpp | 36 + ...gemm_asmem_breg_creg_v1_default_policy.hpp | 56 + .../block_gemm_asmem_bsmem_creg_problem.hpp | 26 - .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 2 +- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 2 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 2 +- ...reg_problem.hpp => block_gemm_problem.hpp} | 4 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 88 + 70 files changed, 9506 insertions(+), 482 deletions(-) create mode 100644 example/ck_tile/01_fmha/fmha_bwd.cpp create mode 100644 example/ck_tile/01_fmha/fmha_bwd.hpp create mode 100644 example/ck_tile/01_fmha/script/benchmark_bwd.sh rename example/ck_tile/01_fmha/script/{benchmark.sh => benchmark_fwd.sh} (100%) create mode 100644 example/ck_tile/01_fmha/script/smoke_test_bwd.sh rename example/ck_tile/01_fmha/script/{smoke_test.sh => smoke_test_fwd.sh} (57%) create mode 100644 include/ck_tile/core/arch/generic_memory_space_atomic.hpp create mode 100644 include/ck_tile/core/tensor/update_tile.hpp create mode 100644 include/ck_tile/core/utility/philox_rand.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_dropout.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_dropout.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp delete mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp delete mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp rename 
include/ck_tile/ops/gemm/block/{block_gemm_areg_bsmem_creg_problem.hpp => block_gemm_problem.hpp} (88%) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 85d25c63d9..e324f85ed8 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,17 +1,29 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt + --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) -# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") @@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_BWD}") +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) @@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) +set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... 
because they are auto-generated if(FMHA_FWD_FAST_EXP2) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp new file mode 100644 index 0000000000..b1249b5eda --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -0,0 +1,932 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_bwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. 
fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k < 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + if(hdim_q % 2 != 0 || hdim_v % 2 != 0) + { + std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; + return false; + } + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + bool use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) + { + std::cerr << "dbias only exists when bias type is elementwise" << std::endl; + return false; + } + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + float p_undrop = 1.0 - p_drop; + uint8_t 
p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaBwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using GemmDataType = typename TypeConfig::GemmDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using AccDataType = typename TypeConfig::AccDataType; + using DDataType = typename TypeConfig::DDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using ODataType = typename TypeConfig::ODataType; + using OGradDataType = typename TypeConfig::OGradDataType; + using QGradDataType = typename TypeConfig::QGradDataType; + using KGradDataType = typename TypeConfig::KGradDataType; + using VGradDataType = typename TypeConfig::VGradDataType; + using BiasGradDataType = typename TypeConfig::BiasGradDataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(3) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + static_cast(2) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * real_seqlen_k * hdim_v + + sizeof(ODataType) * real_seqlen_q * hdim_v + + sizeof(OGradDataType) * real_seqlen_q * hdim_v + + sizeof(QGradDataType) * real_seqlen_q * hdim_q + + sizeof(KGradDataType) * real_seqlen_k * hdim_q + + sizeof(VGradDataType) * real_seqlen_k * hdim_v + + sizeof(LSEDataType) * real_seqlen_q); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == 
mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor lse_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor d_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor dq_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor dk_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor dv_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor do_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor dbias_host( + use_dbias + ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + if(init_method == 0) + { + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 2) + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(do_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + do_buf.ToDevice(do_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask + << std::flush; + + auto fmha_traits = fmha_bwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck_tile::index_t stride_bias = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = (i_perm ? 
shape_seqlen_k * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_bias = 0; + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; + const ck_tile::index_t nhead_stride_dbias = + (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + + return fmha_bwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + do_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + dq_buf.GetDeviceBuffer(), + dk_buf.GetDeviceBuffer(), + dv_buf.GetDeviceBuffer(), + dbias_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
0 : nhead) + : stride_bias, + stride_o, + stride_randval, + stride_do, + stride_dk, + stride_dv, + stride_dbias, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + nhead_stride_randval, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dbias, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + batch_stride_randval, + batch_stride_do, + batch_stride_lsed, + batch_stride_dk, + batch_stride_dv, + batch_stride_dbias, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + p_undrop, + s_randval, + {drop_seed, drop_offset}}; + }(); + + float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + bool pass = true; + + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; + + randval_buf.FromDevice(randval_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + ck_tile::HostTensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + ck_tile::HostTensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + + // reference + // S = scale * Q * K^T + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? 
current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + + if(p_drop > 0) + { + p_hp_host_ref.ForEach( + [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + else + { + p_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + // O = P * V + ck_tile::reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + + // clang-format off + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); + // clang-format on + + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } + } + + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + fmha_bwd(fmha_traits, fmha_args, stream_config_v); + + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] 
- seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + ck_tile::HostTensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + ck_tile::HostTensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + + // clang-format off + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + ck_tile::reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + + if(p_drop > 0) + { + ck_tile::reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } + + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + auto idx_gmo = idx_gmn; + idx_gmo[2] = o; + do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * + ck_tile::type_convert(o_host_refs[wb](idx_gmo)); + } + self(idx_gmn) = ck_tile::type_convert( + p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); + }); + + if(use_dbias) + { + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + + // dV = P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + ck_tile::reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + ck_tile::reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + ck_tile::reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + + ck_tile::HostTensor dq_host_result( + {nhead, real_seqlen_q, hdim_q}); // 
dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + + // clang-format off + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + } + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); + + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp new file mode 100644 index 0000000000..0c6b468951 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
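For orientation, the host-side validation in fmha_bwd.cpp above implements the usual attention backward recurrences; in the notation of the source comments (s is the softmax scale, P the post-dropout probability matrix, dO the incoming output gradient, and D the row-wise dO·O term that the separate fmha_bwd_dot_do_o kernel precomputes into d_ptr):

\begin{aligned}
S &= s\,Q K^T + \mathrm{bias}, \qquad P = \mathrm{softmax}(S), \qquad O = P V, \\
dV &= P^T dO, \qquad dP = dO\,V^T, \qquad D_i = dO_i \cdot O_i, \\
dS_{ij} &= P_{ij}\,(dP_{ij} - D_i), \qquad d\mathrm{Bias} = dS, \\
dQ &= s\,dS\,K, \qquad dK = s\,dS^T Q.
\end{aligned}

When dropout is enabled, P above is the dropped-and-rescaled P_drop, which is why the reference multiplies kept probabilities by rp_undrop = 1/(1 - p_drop).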
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include "bias.hpp" +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + 
ck_tile::index_t batch_stride_dbias; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_lsed, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = 
FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 91fc07d831..5f887f0655 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -110,6 +110,9 @@ auto create_args(int argc, char* argv[]) "11939", "random seed used for initializing input tensors. 
0 for " "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -128,26 +131,11 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else if(init_method == "nf") - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else - { - double rtol = 3e-3; - double atol = 3e-3; - return ck_tile::make_tuple(rtol, atol); - } + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); } template <> @@ -250,6 +238,21 @@ bool run(const ck_tile::ArgParser& arg_parser) mask_info mask = mask_info::decode( arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); if(*seed == 0) @@ -274,21 +277,23 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = FmhaFwdTypeConfig; - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -300,6 +305,11 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_q = real_seqlen_q; } + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + flop += nhead 
* (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); @@ -353,12 +363,16 @@ bool run(const ck_tile::ArgParser& arg_parser) // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); @@ -434,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); @@ -463,8 +478,8 @@ bool run(const ck_tile::ArgParser& arg_parser) << (seqlen_kpads[0] < 0 ? "" : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias - << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout - << std::flush; + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -474,6 +489,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, bias.type, lse, + p_drop > 0.0f, squant}; auto p_compute_element_func = [&]() { @@ -505,8 +521,9 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -518,21 +535,24 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); - const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() : bias_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -554,22 +574,28 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_v, bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, + stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias, + batch_stride_randval, batch_stride_lse, batch_stride_o, mask.left, mask.right, - static_cast(mask.type)}; + static_cast(mask.type), + p_drop, + s_randval, + {drop_seed, drop_offset}}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); @@ -596,6 +622,11 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; bool pass = true; @@ -771,6 +802,17 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, p_host_ref, p_compute_element_func); } + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, @@ -804,9 +846,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b, idx[0], idx[1] + query_offset); - }); + lse_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); bool lse_pass = ck_tile::check_err(lse_host_result, lse_host_ref, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fb3907fec1..3594f61db9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ 
b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // 
data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; }; struct FmhaMasks @@ -88,6 +92,7 @@ struct fmha_fwd_args const void* k_ptr; const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -108,22 +113,28 @@ struct fmha_fwd_args ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; }; template @@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, @@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqlen_k_ptr, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_lse, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, 
args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, + args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); @@ -222,6 +249,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; @@ -264,6 +293,7 @@ struct fmha_fwd_traits mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; + bool has_dropout; bool do_fp8_static_quant; // TODO: padding check is inside this api }; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index f0180d6db7..e0b4b65595 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -83,7 +83,6 @@ TILE_PARTITIONER_MAP = { "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", } -DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -111,8 +110,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_dpad}, {F_dvpad}, {F_bias}, + false, {F_lse}, - {F_squant}, + {F_dropout}, + {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; @@ -123,6 +124,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, @@ -146,7 +148,7 @@ using fmha_kernel_{F_idx} = fmha_epilogue_{F_idx}>; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -191,9 +193,9 @@ MASK_SIMPLIFIED_CHECK_MAP = { "s_mask" : "t.mask_type != mask_enum::no_mask", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && 
({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -233,6 +235,7 @@ class FmhaFwdApiTrait: mask : str bias : str # lse : str # + dropout : str squant : str # spad : str skpad : str @@ -242,7 +245,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -299,6 +302,7 @@ class FmhaFwdPipeline: F_dvpad : str # F_bias : str # true/false F_lse : str # + F_dropout : str # F_squant : str # F_mask : str # value from MASK_MAP @@ -321,6 +325,7 @@ class FmhaFwdPipeline: else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' return n @@ -351,7 +356,7 @@ class FmhaFwdApiPool: inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -365,7 +370,7 @@ class FmhaFwdApiPool: @dataclass class FmhaFwdTileSize: F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along qk seqlen + F_bn0 : int # tile size along k seqlen F_bk0 : int # tile size along qk gemm unroll F_bn1 : int # tile size along v head_dim F_bk1 : int # tile size along kv gemm unroll @@ -424,9 +429,10 @@ class FmhaFwdKernel: F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -461,6 +467,7 @@ class FmhaFwdKernel: mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, spad=self.F_pipeline.F_spad, 
skpad=self.F_pipeline.F_skpad, @@ -489,7 +496,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -500,26 +507,26 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse kernels + # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) else: assert False return pipelines @@ 
-527,7 +534,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) if d == None: continue @@ -551,44 +558,660 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) return (api_pool, gen) -def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_dk_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, 
false>>; + +using fmha_bwd_dv_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, false>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_dk_epilogue_{F_idx}, + fmha_bwd_dv_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return 
f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? 
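[Editor's illustrative sketch, not part of the patch: the GEMM0~GEMM4 comment above, together with the F_bk0..F_bk4 field descriptions in the FmhaBwdDQDKDVTileSize dataclass that follows, implies which dimension each backward GEMM unrolls over. The snippet below only re-states those inferred constraints for the tile entries returned further down by get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype; check_bwd_tile is a hypothetical helper and the divisibility requirement is an assumption, not something the patch asserts.]

# Illustrative only: check that each per-GEMM unroll step (F_bk0..F_bk4) evenly
# divides the dimension it is described as unrolling over in the comments above.
# Tuple order follows the leading FmhaBwdDQDKDVTileSize fields:
# (bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv).
def check_bwd_tile(bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv):
    assert bhdq % bk0 == 0  # GEMM0 (Q@K)        unrolls over hdim_q  (F_bhdq)
    assert bm0  % bk1 == 0  # GEMM1 (P^T@dO^T)   unrolls over seqlen_q (F_bm0)
    assert bhdv % bk2 == 0  # GEMM2 (dO@V)       unrolls over hdim_v  (F_bhdv)
    assert bm0  % bk3 == 0  # GEMM3 (dS^T@Q^T)   unrolls over seqlen_q (F_bm0)
    assert bn0  % bk4 == 0  # GEMM4 (dS@K^T)     unrolls over seqlen_k (F_bn0)

# e.g. the hdim=64 entry used below: a 64x128 block tile with 32-wide unroll steps
check_bwd_tile(64, 128, 32, 32, 32, 32, 32, 64, 64)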
+@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad 
== 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'bwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename 
FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: (autogen_dir / 
kernel.filename).write_text(kernel.template) -def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_api(api_pool, output_dir) + if direction == 'fwd': + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + if direction == 'fwd': + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", description="gen api for CK fmha kernel", ) + parser.add_argument( + "-d", + "--direction", + default='fwd', + choices=['fwd', 'bwd'], + required=False, + help="choose the direction of kernels(default: fwd)" + ) parser.add_argument( "-o", "--output_dir", @@ -623,11 +1246,12 @@ if __name__ == "__main__": default=0, required=False, help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim" + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration" ) args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) + list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) + write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh new file mode 100644 index 0000000000..7591f5442a --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh similarity index 100% rename from example/ck_tile/01_fmha/script/benchmark.sh rename to example/ck_tile/01_fmha/script/benchmark_fwd.sh diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh new file mode 100644 index 0000000000..9fe795471b --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -0,0 +1,33 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1' + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do +for mode in 0 1 ; do +for bias in "n" "e" "a"; do +for dbias in 0 1 ; do +for p_drop in 0.0 0.2; do + +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias 
-dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh similarity index 57% rename from example/ck_tile/01_fmha/script/smoke_test.sh rename to example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 21f679e11a..63813e079c 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -17,17 +17,18 @@ for perm in 0 1 ; do for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do -for bias in "n" "e" "a"; do +for bias in "n" "e" "a" ; do +for p_drop in 0.0 0.2; do -# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME 
$COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done @@ -37,6 +38,7 @@ done done done done +done for perm in 0 1 ; do for bias in "n" "e" "a" ; do diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb19c9154b..bb490cce4a 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" @@ -47,10 +48,12 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index ac2f0cab9c..9c6e85f013 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); +// buffer store ui16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, @@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 @@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d static_cast(coherence)); } } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(uint16_t), + static_cast(coherence)); + } + } else { using r_t = thread_buffer; @@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_th { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp new file mode 100644 index 0000000000..6212db9169 --- /dev/null +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" + +namespace ck_tile { + +CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) +{ + return type_convert(type_convert(a) + type_convert(b)); +} + +CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) +{ + bf16x2_t rtn; + rtn[0] = add_bf16_t(a[0], b[0]); + rtn[1] = add_bf16_t(a[1], b[1]); + return rtn; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); + +template <> +CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) +{ + union U32BF162_ADDR + { + uint32_t* u32_a; + bf16x2_t* bf162_a; + }; + + union U32BF162 + { + uint32_t u32; + bf16x2_t bf162; + }; + + U32BF162_ADDR dword_addr; + U32BF162 cur_v; + U32BF162 new_; + uint32_t old_v, new_v; + dword_addr.bf162_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf162 = add_bf16x2_t(cur_v.bf162, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template +CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + return atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 2) + { + atomic_add(c_style_pointer_cast(p_dst), bit_cast(x)); + } + else if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } +} + +template +CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1)), + "wrong! 
not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicMax(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicMax(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 85d9be1c94..c23c12f295 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); +// ui8 +// using uint8_t +using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); +using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); +using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); +using uint8x16_t = uint8_t __attribute((ext_vector_type(16))); +using uint8x32_t = uint8_t __attribute((ext_vector_type(32))); +using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); + #if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 96b38241c0..ffe8f7a4fd 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" @@ -507,10 +508,10 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } @@ -518,7 +519,7 @@ struct buffer_view(c_style_pointer_cast(&p_data_[i]), x); + atomic_add_g, t_per_x>(&p_data_[i], x); } } } @@ -547,16 +548,16 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } else if(is_valid_element) { - atomic_max(c_style_pointer_cast(&p_data_[i]), x); + atomic_max_g, t_per_x>(&p_data_[i], x); } } diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index c12ad883d9..2efc657013 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index e37bd806de..656309532e 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -16,7 +16,9 @@ namespace ck_tile { -template +template struct tensor_view { using buffer_view = remove_reference_t; @@ -24,6 +26,7 @@ struct tensor_view using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + static constexpr auto DstInMemOp = DstInMemOp_; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -140,6 +143,23 @@ struct tensor_view x); } + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template update( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); @@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, } template (p, desc.get_element_space_size()); - return tensor_view{buffer_view, desc}; + return tensor_view{buffer_view, desc}; } template >{ - old_tensor_view.buf_, new_desc}; + return tensor_view, + remove_cvref_t::DstInMemOp>{old_tensor_view.buf_, new_desc}; } template + CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // move thread's botom tensor coordiante // [x0', x1', ... 
] ==> [offset] // also move window-origin diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp new file mode 100644 index 0000000000..fbce7c4083 --- /dev/null +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.update(dstr_tensor); +} + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.update(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp new file mode 100644 index 0000000000..c49f44ae48 --- /dev/null +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
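+// Counter-based 7-round Philox 4x32 generator (6 key-update rounds in the loop
+// plus one final round), following the flash-attention reference cited below.
+// The 128-bit counter is built from (offset, subsequence) and the 64-bit key
+// from the seed. Illustrative use with the members defined in this file:
+// philox ph(seed, offset); uint8_t r[16]; ph.get_random_16x8(r, subsequence);
+// fills 16 random bytes from a single 4x32 draw.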
+ +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh +class philox +{ + public: + CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_) + : seed(reinterpret_cast(seed_)) + { + + ull2* tmp = reinterpret_cast(&counter); + tmp->x = offset_; + } + + CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const + { + + uint4 counter_ = counter; + ull2* tmp = reinterpret_cast(&counter_); + tmp->y = subsequence; + + uint2 key_ = seed; +// 7-round philox +#pragma unroll + for(int i = 0; i < 6; i++) + { + counter_ = philox_single_round(counter_, key_); + key_.x += kPhilox10A; + key_.y += kPhilox10B; + } + uint4 output = philox_single_round(counter_, key_); + return output; + } + + CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out, + const unsigned long long subsequence) const + { + uint4 tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32_t* out_tmp = reinterpret_cast(&out[0]); + + out_tmp[0] = tmp_ph.x; + out_tmp[1] = tmp_ph.y; + out_tmp[2] = tmp_ph.z; + out_tmp[3] = tmp_ph.w; + } + + private: + struct ull2 + { + uint64_t x; + uint64_t y; + }; + uint4 counter; + const uint2 seed; + + CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const + { + uint2* res; + unsigned long long tmp; + tmp = static_cast(a) * b; + res = reinterpret_cast(&tmp); + return *res; + } + + CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const + { + + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 98a3bb7d7f..09030fa6df 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,6 +11,7 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" +#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index cd0dc38259..43405ee69b 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,7 +156,7 @@ struct HostTensorDescriptor } const std::vector& get_lengths() const { return mLens; } - const std::vector& GetStrides() const { return mStrides; } + const std::vector& get_strides() const { return mStrides; } template std::size_t GetOffsetFromMultiIndex(Is... 
is) const @@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old for(std::size_t i = 0; i < a.get_num_of_dimension(); i++) { new_lengths[i] = a.get_lengths()[new2old[i]]; - new_strides[i] = a.GetStrides()[new2old[i]]; + new_strides[i] = a.get_strides()[new2old[i]]; } return HostTensorDescriptor(new_lengths, new_strides); @@ -327,7 +327,7 @@ struct HostTensor decltype(auto) get_lengths() const { return mDesc.get_lengths(); } - decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + decltype(auto) get_strides() const { return mDesc.get_strides(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } @@ -481,6 +481,34 @@ struct HostTensor return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } + HostTensor transpose(std::vector axes = {}) const + { + if(axes.empty()) + { + axes.resize(this->get_num_of_dimension()); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if(axes.size() != mDesc.get_num_of_dimension()) + { + throw std::runtime_error( + "HostTensor::transpose(): size of axes must match tensor dimension"); + } + std::vector tlengths, tstrides; + for(const auto& axis : axes) + { + tlengths.push_back(get_lengths()[axis]); + tstrides.push_back(get_strides()[axis]); + } + HostTensor ret(*this); + ret.mDesc = HostTensorDescriptor(tlengths, tstrides); + return ret; + } + + HostTensor transpose(std::vector axes = {}) + { + return const_cast const*>(this)->transpose(axes); + } + typename Data::iterator begin() { return mData.begin(); } typename Data::iterator end() { return mData.end(); } diff --git a/include/ck_tile/host/reference/reference_batched_dropout.hpp b/include/ck_tile/host/reference/reference_batched_dropout.hpp new file mode 100644 index 0000000000..242101bf4d --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_dropout.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_dropout(HostTensor& in_out_b_m_n, + const HostTensor& randval_b_m_n, + const uint8_t& p_undrop_in_uint8_t, + const float scale) +{ + const int N = in_out_b_m_n.mDesc.get_lengths()[2]; + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + float tmp = ck_tile::type_convert(in_out_b_m_n(batch, m, n)) * scale; + in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t + ? 
ck_tile::type_convert(tmp) + : DataType(0); + } + }; + + make_ParallelTensorFunctor( + f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1122bf87b7..5684868306 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -4,10 +4,24 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp new file mode 100644 index 0000000000..1f0fe2bd64 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
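+// BlockDropout: per-tile dropout helper shared by the fwd and bwd FMHA
+// pipelines. Each (batch, head, lane) gets its own Philox stream (see the
+// constructor's offset computation); random bytes are compared against
+// p_undrop_in_uint8_t to build the keep mask applied to the P tile (kept
+// entries are rescaled by rp_undrop = 1 / (1 - p_drop) on the forward path),
+// and the raw random values are optionally written back to global memory
+// when is_store_randval is set.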
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. 
+ constexpr auto randval_block_inner_part_dstr_encoding = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + else + { + return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + }(); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + + auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto 
idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx1 = + tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval.kThreadElementSpaceSize == 16); + + const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + int block_row_start = (start_m0_idx / WG::kM) + i_m0; + int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + randval(r_idx) = random_uint8_t[i_random_idx++]; + constexpr auto p_idx0 = + tile_distributed_index{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? 
p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + ck_tile::philox ph; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 7fb1c19b5f..f43de4573a 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -141,6 +141,36 @@ struct GenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -160,14 +190,14 @@ struct GenericAttentionMask } else { - return i_x >= x_end; + return i_x >= x_end || i_y >= y_total; } } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! 
assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -283,13 +343,13 @@ struct SimplifiedGenericAttentionMask index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end; + return i_x < x_start || i_x >= x_end || i_y >= y_total; } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size, { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp new file mode 100644 index 0000000000..e713cefbda --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -0,0 +1,1421 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
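+// Backward (dQ/dK/dV) FMHA kernel. The comment block below spells out the
+// gradients computed by this kernel; dX denotes the gradient of the loss
+// with respect to X, and ^T denotes a transposed view.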
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q] +// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v] +// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v]) +// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q]) +// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k] +// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1] +// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1] + +namespace ck_tile { + +template +struct FmhaBwdDQDKDVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using KGradEpiloguePipeline = ck_tile::remove_cvref_t; + using VGradEpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using GemmDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + using QGradDataType = ck_tile::remove_cvref_t; + using KGradDataType = ck_tile::remove_cvref_t; + using VGradDataType = ck_tile::remove_cvref_t; + using BiasGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? 
n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaBwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaBwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lse_ptr; + const void* do_ptr; + const void* d_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_head_q; + ck_tile::index_t nhead_ratio_qk; + float raw_scale; +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale; +#endif + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + + ck_tile::index_t batch_stride_lsed; + }; + + struct FmhaBwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaBwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaBwdCommonBiasGradKargs + { + void* dbias_ptr = nullptr; + ck_tile::index_t stride_dbias = 0; + ck_tile::index_t nhead_stride_dbias = 0; + }; + + struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs + { + ck_tile::index_t batch_stride_dbias = 0; + }; + + struct FmhaBwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaBwdCommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset, + const float raw_scale) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + drop_seed = 
std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaBwdBatchModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + }; + + struct FmhaBwdGroupModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dk, + ck_tile::index_t batch_stride_dv, + ck_tile::index_t batch_stride_dbias, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_do, + batch_stride_dk, + batch_stride_dv}; + + if 
constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + kargs.batch_stride_dbias = batch_stride_dbias; + } + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasBiasGrad) + { + 
kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), + KGradEpiloguePipeline::GetSmemSize(), + VGradEpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_lsed = 0; + long_index_t batch_offset_dk = 0; + long_index_t batch_offset_dv = 0; + long_index_t batch_offset_dbias = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = query_start * kargs.stride_dbias; + } + else + { + batch_offset_dbias = key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_k <= i_n0) + { + return; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * 
kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = static_cast(i_batch) * kargs.batch_stride_dbias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + const LSEDataType* lse_ptr = reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_dk; + VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_dv; + + // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto qt_dram_naive = + transform_tensor_view(q_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto qt_dram = [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto kt_dram_naive = + transform_tensor_view(k_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_k)), + 
make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto kt_dram = [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kVLoadOnce) + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + lse_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + const auto do_dram = [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto dot_dram_naive = + transform_tensor_view(do_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto dot_dram = [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto dq_dram = [&]() { + const auto dq_dram_naive = make_naive_tensor_view( + dq_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto qt_dram_window = + make_tile_window(qt_dram, + [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto kt_dram_window = + make_tile_window(kt_dram, + [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, i_n0}); + + auto v_dram_window = make_tile_window( + v_dram, + [&]() { + if 
constexpr(FmhaPipeline::kVLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto do_dram_window = make_tile_window( + do_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto dot_dram_window = + make_tile_window(dot_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto dq_dram_window = make_tile_window( + dq_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto lse_dram_window = + make_tile_window(lse_dram, make_tuple(number{}), {0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dbias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(kHasBiasGrad) + { + BiasGradDataType* dbias_ptr = + reinterpret_cast(kargs.dbias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dbias + + batch_offset_dbias; + + auto dbias_dram = [&]() { + const auto dbias_dram_naive = + make_naive_tensor_view( + dbias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_dbias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(dbias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
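                // The slope read below fetches one ALiBi slope per (batch, head):
                // kargs.alibi_slope_ptr is indexed by i_batch_ * kargs.alibi_slope_stride + i_nhead_.
                // Under CK_TILE_FMHA_FWD_FAST_EXP2 the slope is pre-multiplied by log2(e) so the
                // exp2-based softmax path applies the same effective bias as the exp-based path;
                // with kHasMask the bias is built from the same left/right window parameters as
                // the attention mask, otherwise it is anchored at the bottom-right corner
                // (AlibiMode::FROM_BOTTOM_RIGHT).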
+ AccDataType slope = *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + // dropout + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + scale_rp_undrop = kargs.scale_rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + qt_dram_window, + k_dram_window, + kt_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + dot_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dbias_dram_window, + mask, + position_encoding, + kargs.raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + kargs.scale, +#endif + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); + + auto dk_dram = [&]() { + const auto dk_dram_naive = make_naive_tensor_view( + dk_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_dk, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dk_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dv_dram = [&]() { + const auto dv_dram_naive = make_naive_tensor_view( + dv_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_dv, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dv_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dk_dram_window = make_tile_window( + dk_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + auto dv_dram_window = make_tile_window( + dv_dram, + 
make_tuple(number{}, number{}), + {i_n0, 0}); + + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); + VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + } +}; + +template +struct FmhaBwdOGradDotOKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; + static constexpr ck_tile::index_t kM0 = kBlockSize; + static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim; + + using DDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
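    // What this kernel produces, per query row m (scalar view):
    //   D[m] = p_undrop * sum_{d < hdim_v} O[m][d] * dO[m][d]
    // i.e. the rowwise dot product of the forward output with its gradient, scaled by the
    // dropout keep-probability. The dq/dk/dv kernel above loads this back as the "D" tile
    // (d_dram_window) and uses it in dS^T = P^T * (dP^T - D).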
+ struct FmhaBwdOGradDotOCommonKargs + { + const void* o_ptr; + const void* do_ptr; + void* d_ptr; + + float p_undrop; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t stride_do; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_d; + ck_tile::index_t batch_stride_d; + }; + + struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs + { + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + seqlen_q, + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + batch_stride_do, + batch_stride_o}; + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + -1, // seqlen will be updated by another pointer + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + reinterpret_cast(seqstart_q_ptr)}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + + long_index_t batch_offset_o = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_d = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.stride_o; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_do = 
static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + } + + // for simplicity, batch stride we just modify the pointer + const ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_d + + batch_offset_d; + + // O/dO/D DRAM and DRAM window + const auto o_dram = [&]() { + auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto do_dram = [&]() { + auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + return pad_tensor_view(do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto do_dram_window = + make_tile_window(do_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {i_m0}); + + FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp new file mode 100644 index 0000000000..bc875b8e5a --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
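// Grid mapping summarized from the two partitioners defined below:
//   FmhaBwdTilePartitioner          -> dim3(ceil(seqlen_k / kN0), nhead, batch), one x-block per kN0-wide key tile
//   FmhaBwdOGradDotOTilePartitioner -> dim3(ceil(seqlen_q / kBlockSize), nhead, batch), one x-block per kBlockSize query rows
// Both simply unpack (blockIdx.x, blockIdx.y, blockIdx.z) into (i_tile, i_nhead, i_batch).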
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaBwdTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +template +struct FmhaBwdOGradDotOTilePartitioner +{ + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9992d56ea9..6624bf1a9b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,11 +9,11 @@ #include #include -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] -// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] namespace ck_tile { @@ -32,6 +32,8 @@ struct FmhaFwdKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -45,6 +47,7 @@ struct FmhaFwdKernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -84,7 +87,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? 
_SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -111,6 +114,7 @@ struct FmhaFwdKernel ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; + ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; @@ -163,11 +167,35 @@ struct FmhaFwdKernel { void* lse_ptr = nullptr; ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + struct FmhaFwdCommonDropoutKargs { - ck_tile::index_t batch_stride_lse = 0; + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; }; struct FmhaFwdBatchModeKargs @@ -178,8 +206,9 @@ struct FmhaFwdKernel FmhaFwdAlibiKargs, FmhaFwdEmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -196,7 +225,8 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -211,12 +241,14 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -225,22 +257,28 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -250,6 +288,7 @@ struct FmhaFwdKernel 
seqlen_k, hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -268,6 +307,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v, @@ -302,6 +342,15 @@ struct FmhaFwdKernel kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -312,6 +361,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -319,6 +369,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -327,16 +378,22 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -346,6 +403,7 @@ struct FmhaFwdKernel -1, // hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -364,6 +422,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -389,12 +448,21 @@ struct FmhaFwdKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -426,12 +494,13 @@ struct FmhaFwdKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -455,7 +524,11 @@ struct 
FmhaFwdKernel } if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; } batch_offset_o = query_start * kargs.stride_o; @@ -493,6 +566,11 @@ struct FmhaFwdKernel { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -666,6 +744,62 @@ struct FmhaFwdKernel } }(); + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -723,6 +857,7 @@ struct FmhaFwdKernel identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func + randval_dram_window, lse_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func @@ -731,7 +866,8 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } else { @@ -739,11 +875,13 @@ struct FmhaFwdKernel k_dram_window, v_dram_window, bias_dram_window, + randval_dram_window, lse_dram_window, mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index e40b006685..2dca84b786 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp new file mode 100644 index 0000000000..f189937038 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdOGradDotO +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kVHeaddim = Problem::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + template + CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + DDramBlockWindowTmp& d_dram_block_window_tmp, + float p_undrop) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == + OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + auto o_dram_window = + make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(), + o_dram_block_window_tmp.get_window_lengths(), + o_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreODramTileDistribution()); + + auto o = load_tile(o_dram_window); + + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + do_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreOGradDramTileDistribution()); + + auto do_ = load_tile(do_dram_window); + + // declare d + constexpr auto d_dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); + + auto d = make_static_distributed_tensor(d_dstr); + + clear_tile(d); // Initialize D + + constexpr auto o_spans = decltype(o)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + d(i_idx) += + (type_convert(o[i_j_idx]) * type_convert(do_[i_j_idx])); + }); + }); + + tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d); + + store_tile(d_dram_block_window_tmp, d); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp new file mode 100644 index 
0000000000..7843ab33a1 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// These templates are not used here. +using BlockFmhaBwdOGradDotODefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp new file mode 100644 index 0000000000..3444567508 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -0,0 +1,848 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSKTSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = true; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide 
buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_kts_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& kt_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = 
static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto kt_lds = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template 
MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto kt_dram_block_window = kt_dram_block_window_tmp; + + auto kt_dram_window = make_tile_window( + kt_dram_block_window.get_bottom_tensor_view(), + kt_dram_block_window.get_window_lengths(), + kt_dram_block_window.get_window_origin(), + Policy::template MakeKTDramTileDistribution()); // K^T DRAM tile window for + // load + + auto kt_block_tile = load_tile(kt_dram_window); + + auto kt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledKTRegBlockDescriptor()); + shuffle_tile(kt_shuffle_tmp, kt_block_tile); + + store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 
bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + 
get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = 
decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } 
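    // Sketch of the math performed by one iteration of the main loop above
    // (notation mirrors the STAGE comments; D is the per-row term streamed from the
    // D window, which in standard FlashAttention backward is dot(dO_i, O_i)):
    //   S^T   = (Q K^T)^T          STAGE 1 (gemm_0; softmax scale applied in STAGE 2)
    //   P^T   = softmax(S)^T       STAGE 2 (recovered from the stored LSE, with
    //                                       bias/ALiBi/mask/dropout as configured)
    //   dV   += P^T dO             STAGE 3 (gemm_1)
    //   dP^T  = (dO V^T)^T         STAGE 4 (gemm_2)
    //   dS^T  = P^T * (dP^T - D)   STAGE 5 (elementwise)
    //   dK   += dS^T Q             STAGE 6 (gemm_3)
    //   dQ   += dS K               STAGE 7 (gemm_4; flushed with update_tile each iteration)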
+}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp new file mode 100644 index 0000000000..a05fbf252f --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k & k^t located in lds. +using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp new file mode 100644 index 0000000000..dec421c1e6 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -0,0 +1,821 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static 
constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + 
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using 
SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + 
dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + 
sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = 
decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } 
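    // Write-back strategy used above: a workgroup owns a single kN0-wide K/V tile and
    // iterates over the kM0-wide Q tiles selected by the mask range, so dK/dV stay in
    // registers and are returned once after the loop, while dQ receives a partial
    // contribution from every K/V tile and is therefore written out with update_tile()
    // on every iteration (presumably combining with the partial results produced by the
    // workgroups handling the other K/V tiles). Final scaling: dQ and dK are multiplied
    // by the softmax scale (scale_rp_undrop, i.e. scale/(1-p), when dropout is enabled,
    // otherwise raw_scale), and dV by rp_undrop = 1/(1-p) only when dropout is enabled.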
+}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp new file mode 100644 index 0000000000..cc4e6304d0 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k located in lds. +using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp new file mode 100644 index 0000000000..97487bb71e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -0,0 +1,692 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = true; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = true; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = 
Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "qs_ks_vr_dos"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + 
"wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + auto qt_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + auto dot_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + 
// init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + 
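        // Shared-memory layout set up above: K sits at the base of the LDS arena and
        // stays resident across the whole loop; the Q and OGrad staging buffers follow
        // it, and the Q^T / K^T / OGrad^T windows are transposed descriptors over those
        // same buffers rather than separate allocations. The SGrad and BiasT/BiasGradT
        // tiles reuse the region past K + Q + OGrad, so dS staging and bias shuffling
        // share storage.
        //
        // The loop below walks the Q/QGrad/OGrad/LSE/D windows from seqlen_q_start in
        // steps of kM0 for num_total_loop iterations; within one iteration the five
        // block GEMMs are unrolled along their reduction dimensions:
        //   k0_loops = kQKHeaddim / kK0  (Q@K),       k1_loops = kM0 / kK1  (P^T@dO^T),
        //   k2_loops = kVHeaddim / kK2   (dO@V),      k3_loops = kM0 / kK3  (dS^T@Q^T),
        //   k4_loops = kN0 / kK4         (dS@K^T).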
+ auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + clear_tile(st_acc); // Initialize S^T + store_tile(q_lds_window, q_block_tile); // LDS write + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + }); + } + + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + block_sync_lds(); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + 
st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + block_sync_lds(); + store_tile(do_lds_window, do_block_tile); // store the prefetch + + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence{}, sequence<(i_k1 + 1) * kK1, kN0>{}), + get_slice_tile(dot_lds_window, + sequence<0, i_k1 * kK1>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + clear_tile(dpt_acc); // Initialize PGrad^T + + static_for<0, k2_loops, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + get_slice_tile(do_lds_window, + sequence<0, i_k2 * kK2>{}, + sequence{}), + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + }); + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? 
(dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + block_sync_lds(); + const auto dst_gemm = cast_tile(dst); + + static_for<0, k3_loops, 1>{}([&](auto i_k3) { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence{}, sequence<(i_k3 + 1) * kK3, kN0>{}), + get_slice_tile(qt_lds_window, + sequence<0, i_k3 * kK3>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp new file mode 100644 index 0000000000..ac81990e07 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, q & k & do located in lds. 
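// Judging from the k*LoadOnce constants declared in BlockFmhaBwdDQDKDVPipelineQSKSVROGradS,
// the flags fixed by this alias are presumably QLoadOnce/KLoadOnce/VLoadOnce/OGradLoadOnce
// = true, with the remaining transposed-load flags = false.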
+using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..a013ee3d57 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -0,0 +1,1343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineDefaultPolicy +{ + static constexpr bool QLoadOnce = + QLoadOnce_; // whether q loads its whole block length (qkhdim) into LDS at once + static constexpr bool QTLoadOnce = + QTLoadOnce_; // whether q^t loads its whole block length (qkhdim) into LDS at once + static constexpr bool KLoadOnce = + KLoadOnce_; // whether k loads its whole block length (qkhdim) into LDS at once + static constexpr bool KTLoadOnce = + KTLoadOnce_; // whether k^t loads its whole block length (qkhdim) into LDS at once + static constexpr bool VLoadOnce = + VLoadOnce_; // whether v loads its whole block length (vhdim) into VGPRs at once + static constexpr bool OGradLoadOnce = + OGradLoadOnce_; // whether do loads its whole block length (vhdim) into LDS at once + static constexpr bool OGradTLoadOnce = + OGradTLoadOnce_; // whether do^t loads its whole block length (vhdim) into LDS at once + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(VLoadOnce) + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + else + { + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr
auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + // TODO: not correct! 
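+ // total_pixels is the number of bias elements each thread owns in this tile; a larger
+ // per-thread footprint permits a wider vector width for the transposed access below.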
+ if constexpr(total_pixels > 32) + return 8; + else + return 4; + } + + // these are for lds + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + { + // TODO: this is for 3d layout + using BiasDataType = remove_cvref_t; + return 16 / sizeof(BiasDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() + { + // TODO: this is for 3d layout + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() + { + // TODO: this is for 3d layout + using GemmDataType = remove_cvref_t; + return 16 / sizeof(GemmDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return x_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return xt_lds_block_desc; + } + + template + 
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + static_assert(PixelsPerRow % KPack == 0); + constexpr index_t NPerRow = PixelsPerRow / KPack; + static_assert(MNPerBlock % NPerRow == 0); + static_assert(KPerBlock % KPack == 0); + + constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + xt_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPack = GetSmemKPackV(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + 
constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPack = GetSmemKPackSGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType); + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() + { + using KDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType); + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() + { + using QGradDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() + { + using BiasDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType); + constexpr index_t kKPack = GetSmemKPackBias(); + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kMPerBlock % kKPack == 0); + + constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto biast_lds_block_desc = transform_tensor_descriptor( + biast_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return biast_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + constexpr index_t smem_size_q = sizeof(typename 
Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(); + return smem_size_q; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() + { + constexpr index_t smem_size_qt = [&]() { + if constexpr(QLoadOnce && !QTLoadOnce) + return 0; + else + return sizeof(typename Problem::QDataType) * + MakeQTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_qt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * + MakeKLdsBlockDescriptor().get_element_space_size(); + return smem_size_k; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() + { + constexpr index_t smem_size_kt = [&]() { + if constexpr(KLoadOnce && !KTLoadOnce) + return 0; + else + return sizeof(typename Problem::KDataType) * + MakeKTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_kt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + constexpr index_t smem_size_v = [&]() { + if constexpr(VLoadOnce) + return 0; + else + return sizeof(typename Problem::VDataType) * + MakeVLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_v; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() + { + constexpr index_t smem_size_do = + sizeof(typename Problem::OGradDataType) * + MakeOGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_do; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() + { + constexpr index_t smem_size_dot = [&]() { + if constexpr(OGradLoadOnce && !OGradTLoadOnce) + return 0; + else + return sizeof(typename Problem::OGradDataType) * + MakeOGradTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_dot; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() + { + constexpr index_t smem_size_ds = + sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_ds; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() + { + constexpr index_t smem_size_bias = [&]() { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasTLdsBlockDescriptor().get_element_space_size(); + else + return 0; + }(); + return smem_size_bias; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_qt = GetSmemSizeQT(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_kt = GetSmemSizeKT(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_dot = GetSmemSizeOGradT(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias); + + index_t smem_size = 0; + + if constexpr(QLoadOnce && OGradLoadOnce) + smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot + + smem_size_transpose; // 1~4 & 10 + else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += smem_size_q + smem_size_qt + + max(smem_size_do, + smem_size_dot, + smem_size_transpose); // 5/7/11 TODO: 
Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce) + smem_size += smem_size_do + smem_size_dot + + max(smem_size_q, + smem_size_qt, + smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += max(smem_size_q, + smem_size_qt, + smem_size_do, + smem_size_dot, + smem_size_transpose); // 9/13 TODO: Multiple buffers strategy + + // 14/15 needs to be adjusted + if constexpr(KLoadOnce) + smem_size += (smem_size_k + smem_size_kt); // 1~13 + else + smem_size = + max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy + + return max(smem_size, smem_size_v); // 15 + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { 
+ if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + + constexpr index_t K1 = GetAlignmentOGrad(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
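+ // Each thread owns total_pixels elements of the Q^T tile; N1 of them form one transposed
+ // vector, so K3 = total_pixels / N1 below is the per-thread repeat count along K.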
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
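+ // For the bias tile the per-thread elements are split along M rather than K:
+ // M3 = total_pixels / N1 below is the per-thread repeat count along the M dimension.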
+ constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + static_assert(kMPerBlock == M0 * M1 * M2 * M3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<1, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBRegCRegV1CustomPolicy; + + return BlockGemmASmemBRegCRegV1{}; + } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + // { + // using BlockGemmProblem = + // BlockGemmPipelineProblem>; + // constexpr auto 
warp_gemm = []() { + // if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + // } + // else if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + // } + // }(); + + // using BlockGemmPolicy = + // BlockGemmASmemBSmemCRegV1CustomPolicy; + + // return BlockGemmASmemBSmemCRegV1{}; + // } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp new file mode 100644 index 0000000000..a54a9fcb32 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaBwdPipelineEnum +{ + KSKTSVR = 0, + QSKSVROGradS, + KSVR, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp new file mode 100644 index 0000000000..7b787e9f36 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +template +struct BlockFmhaBwdOGradDotOPipelineProblem +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kVHeaddim = kVHeaddim_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 159fb40743..1b72b60054 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -13,6 +13,7 @@ template struct BlockFmhaPipelineProblem { - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + 
using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; @@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 60650761d8..06ce3a6514 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -14,19 +15,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS }); }); + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + block_sync_lds(); if constexpr(std::is_same_v) { @@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto @@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS mask, position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8a19deb02a..9939a474b2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -15,19 +16,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVSAsync { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -118,6 +121,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -137,6 +141,7 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -145,7 +150,8 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -292,6 +298,9 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -558,6 +567,17 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -688,6 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto @@ -695,11 +716,13 @@ struct BlockFmhaPipelineQRKSVSAsync const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -709,6 +732,7 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -717,7 +741,8 @@ struct BlockFmhaPipelineQRKSVSAsync mask, position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 80f40f8154..f4767de0e9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -14,19 +14,20 @@ namespace ck_tile { template struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -106,20 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, PositionEncoding /*position_encoding*/, float scale_s, float descale_qk, float descale_sv, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& /*dropout*/) const // not supported { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index e12e767069..bc9ca93d09 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -13,19 +13,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQSKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4fda6f008f..12af81bb98 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { // TODO: assume Q is in register // TODO: assume K/V has same data type @@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + if constexpr(AsyncCopyK) + { + return GetSmemSizeKV() + GetSmemSizeDropout(); + } + else + { + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr 
ck_tile::index_t GetSmemSizeDropout() + { + if constexpr(Problem::kHasDropout) + { + constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm(); + constexpr auto config = + decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + + return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); + } + else + { + return 0; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index d8a290b09c..64a61e94d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -43,4 +43,53 @@ struct TileFmhaShape ck_tile::tensor_layout::gemm::ColumnMajor>; }; +template +struct TileFmhaBwdShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + using Gemm2BlockWarps = remove_cvref_t; + using Gemm2WarpTile = remove_cvref_t; + using Gemm3BlockWarps = remove_cvref_t; + using Gemm3WarpTile = remove_cvref_t; + using Gemm4BlockWarps = remove_cvref_t; + using Gemm4WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + + static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) && + NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{})); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = + BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll + static constexpr index_t kK1 = + BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll + static constexpr index_t kK2 = + BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll + static constexpr index_t kK3 = + BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll + static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll + static constexpr index_t kQKHeaddim = + BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or + // K/K^T at once + static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline + // that need load V at once +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 6cb6449f16..973ffa9f82 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
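Editor's note: in the policy change above, the old GetSmemSize becomes GetSmemSizeKV, and the new GetSmemSize either stacks the dropout staging buffer after K/V (AsyncCopyK path) or takes the max of the two regions (synchronous path, where they may alias). GetSmemSizeDropout reserves (MWarp * WG::kM + 1) rows of WG::kN uint8 values, the extra row acting as padding. A constexpr sketch of that sizing logic with assumed warp-tile numbers (the real values come from the warp-gemm config chosen by the policy):

    // Constexpr sketch of the LDS sizing logic; kMWarp/kWgM/kWgN/kSmemKV are assumptions.
    #include <algorithm>
    #include <cstdio>

    constexpr int kMWarp = 4, kWgM = 32, kWgN = 32; // assumed 4x1 warp layout, 32x32 MFMA tile
    constexpr int kSmemKV = 32 * 1024;              // assumed K+V staging size in bytes

    constexpr int smem_size_dropout()
    {
        constexpr int kMPerStep = kMWarp * kWgM;
        constexpr int kNPerStep = kWgN;
        return (kMPerStep + 1) * kNPerStep * int(sizeof(unsigned char)); // +1 row of padding
    }

    constexpr int smem_size(bool async_copy_k)
    {
        // async K copy keeps K/V resident, so the dropout bytes must live after them;
        // otherwise the two regions can overlap and only the larger one matters.
        return async_copy_k ? kSmemKV + smem_size_dropout()
                            : std::max(kSmemKV, smem_size_dropout());
    }

    int main()
    {
        std::printf("dropout LDS: %d bytes, total(async)=%d, total(sync)=%d\n",
                    smem_size_dropout(), smem_size(true), smem_size(false));
    }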
#pragma once @@ -13,7 +13,9 @@ template struct TileFmhaTraits @@ -23,9 +25,21 @@ struct TileFmhaTraits static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaBwdOGradDotOTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c7ebcf9606..a89536e6eb 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -3,20 +3,21 @@ #pragma once -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp deleted file mode 100644 index 1053c751ad..0000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
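Editor's note: TileFmhaBwdShape above packs the per-block tile sizes of the five backward GEMMs into one nine-entry BlockTile (kM0, kN0, kK0..kK4, kQKHeaddim, kVHeaddim), and TileFmhaTraits gains kHasBiasGrad and kHasDropout. A small sketch of how such a nine-entry description maps onto named constants; the concrete sizes are illustrative, not the generated instances:

    // Sketch of the backward block-tile layout (sizes are assumptions for illustration).
    #include <array>
    #include <cstdio>

    // order: M0, N0, K0(Q@K^T), K1(P^T@dO), K2(dO@V^T), K3(dS^T@Q), K4(dS@K), QK headdim, V headdim
    constexpr std::array<int, 9> kBwdBlockTile = {128, 128, 32, 32, 32, 32, 32, 64, 64};

    constexpr int kM0 = kBwdBlockTile[0], kN0 = kBwdBlockTile[1];
    constexpr int kK0 = kBwdBlockTile[2], kK1 = kBwdBlockTile[3], kK2 = kBwdBlockTile[4];
    constexpr int kK3 = kBwdBlockTile[5], kK4 = kBwdBlockTile[6];
    constexpr int kQKHeaddim = kBwdBlockTile[7], kVHeaddim = kBwdBlockTile[8];

    // illustrative sanity check: head dims cover whole unroll steps of the GEMMs that read them
    static_assert(kQKHeaddim % kK0 == 0 && kVHeaddim % kK2 == 0);

    int main()
    {
        std::printf("bwd block tile: %dx%d over q/k seqlen, headdims %d/%d\n",
                    kM0, kN0, kQKHeaddim, kVHeaddim);
    }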
- -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -// Problem Description for BlockGemmARegBGmemCReg -template -struct BlockGemmARegBGmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index 7799bbe918..f097790ae6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem, + BlockGemmProblem, BlockGemmARegBSmemCRegV1DefaultPolicy>; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp index 4156398bd3..0a17b05353 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index aac9c4f552..84883d6ed8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
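Editor's note: the per-variant problem structs (areg_bgmem, asmem_bsmem) are deleted and the areg_bsmem one is renamed to a single BlockGemmProblem, which BlockGemmARegBGmemCRegV1 now forwards to its underlying A-reg/B-smem implementation. A sketch of the shared problem-descriptor idea; the field names follow the deleted struct shown above, while the TileShape type and float element types are stand-ins (real instances use fp16/bf16 and ck_tile's BlockGemmShape):

    // Sketch of a unified block-GEMM problem descriptor (names partly assumed).
    #include <cstdio>

    template <int M, int N, int K>
    struct TileShape { static constexpr int kM = M, kN = N, kK = K; };

    template <typename AData, typename BData, typename CData, int BlockSize, typename Shape>
    struct BlockGemmProblemSketch
    {
        using ADataType = AData;
        using BDataType = BData;
        using CDataType = CData;
        using BlockGemmShape = Shape;
        static constexpr int kBlockSize = BlockSize;
    };

    // every variant (A in regs/LDS, B in regs/LDS/global) can now consume the same problem type
    using Problem = BlockGemmProblemSketch<float, float, float, 256, TileShape<128, 128, 32>>;

    int main()
    {
        std::printf("block %d threads, tile %dx%dx%d\n",
                    Problem::kBlockSize, Problem::BlockGemmShape::kM,
                    Problem::BlockGemmShape::kN, Problem::BlockGemmShape::kK);
    }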
#pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" namespace ck_tile { @@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1 }); } - // C = A * B - template - CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const { - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - const index_t iNWarp = get_warp_id() % NWarp; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return 
c_block_tensor; + } - // constrcut from A-block-tensor from A-Block-tensor-tmp - // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent - // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); - - a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // Construct C-Block-HostTensor - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - using AWarpDstr = typename WG::AWarpDstr; - using CWarpDstr = typename WG::CWarpDstr; - - using AWarpTensor = typename WG::AWarpTensor; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp 
b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp index 779113d96a..f998c67c95 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 8073989264..9b10d435b6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 405d7f1258..4a82702c1f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp index 8bcd04b7b0..20dcf2c270 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp index c17385b8e5..e90500c28c 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp new file mode 100644 index 0000000000..65ce1a9b8f --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
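Editor's note: the rewrite of BlockGemmARegBSmemCRegV1 above takes MPerBlock/NPerBlock/KPerBlock from BlockGemmShape instead of the window lengths, and replaces the duplicated "C = A * B" body with a MakeCBlockTile() factory plus the accumulating overload, so the value-returning call is just a thin wrapper. A plain-C++ sketch of that refactor pattern (std::array tiles stand in for distributed tensors; in the real pipelines the caller is expected to clear the accumulator tile explicitly):

    // Sketch: one accumulating body plus a tile factory, composed into C = A * B.
    #include <array>
    #include <cstdio>

    template <int M, int N, int K>
    struct BlockGemmSketch
    {
        using CTile = std::array<float, M * N>;

        static CTile make_c_block_tile() { return CTile{}; } // zero-initialized here

        // C += A * B : the single place the hot loop lives now
        void operator()(CTile& c, const std::array<float, M * K>& a,
                        const std::array<float, K * N>& b) const
        {
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    for(int k = 0; k < K; ++k)
                        c[m * N + n] += a[m * K + k] * b[k * N + n];
        }

        // C = A * B : now just composition instead of a duplicated body
        CTile operator()(const std::array<float, M * K>& a,
                         const std::array<float, K * N>& b) const
        {
            auto c = make_c_block_tile();
            (*this)(c, a, b);
            return c;
        }
    };

    int main()
    {
        BlockGemmSketch<2, 2, 2> gemm;
        auto c = gemm({1, 2, 3, 4}, {5, 6, 7, 8});
        std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
    }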
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmASmemBRegCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode); + + // constrcut from B-block-tensor from B-Block-tensor-tmp + // FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent + // distribution + auto b_block_tensor = + make_static_distributed_tensor(b_block_dstr); + + b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer(); + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t 
kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + 
{ + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..5a17578f69 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmASmemBRegCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..cd16f09c37 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBRegCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBRegCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp deleted file mode 100644 index ed772891a4..0000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro 
Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -// Problem Description for BlockGemmASmemBSmemCRegV1 -template -struct BlockGemmASmemBSmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 40da16d820..ac45221709 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 319711088f..2436457ec1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index fbb957727d..f798d6e815 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
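Editor's note: the new BlockGemmASmemBRegCRegV1 introduced above is the mirror image of the existing A-reg/B-smem GEMM: B is a per-thread distributed tensor, A is read warp-by-warp through tile windows over LDS, and the hot loop runs K outermost, M in the middle (one A fragment load per (m, k)), and N innermost so that fragment is reused across every N iteration. A host-side sketch of just that loop nesting, with plain arrays standing in for LDS windows and register tensors and assumed iteration counts:

    // Host-side sketch of the A-smem/B-reg loop order (all sizes are illustrative).
    #include <cstdio>
    #include <vector>

    constexpr int MIter = 2, NIter = 2, KIter = 2; // per-warp iteration counts
    constexpr int WgM = 4, WgN = 4, WgK = 4;       // warp-gemm fragment sizes

    int main()
    {
        std::vector<float> A(MIter * WgM * KIter * WgK, 1.f); // "LDS" tile
        std::vector<float> B(KIter * WgK * NIter * WgN, 1.f); // "register" tile
        std::vector<float> C(MIter * WgM * NIter * WgN, 0.f);

        const int lda = KIter * WgK, ldb = NIter * WgN, ldc = NIter * WgN;

        for(int k = 0; k < KIter; ++k)
            for(int m = 0; m < MIter; ++m)     // a_warp_tensor = load_tile(a_warp_windows(m)(k))
                for(int n = 0; n < NIter; ++n) // reuse that A fragment across every N iteration
                    for(int mm = 0; mm < WgM; ++mm)
                        for(int nn = 0; nn < WgN; ++nn)
                            for(int kk = 0; kk < WgK; ++kk)
                                C[(m * WgM + mm) * ldc + n * WgN + nn] +=
                                    A[(m * WgM + mm) * lda + k * WgK + kk] *
                                    B[(k * WgK + kk) * ldb + n * WgN + nn];

        std::printf("C[0] = %g (expect %d)\n", C[0], KIter * WgK);
    }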
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp similarity index 88% rename from include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp rename to include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 7a0390a8a2..d8f66c81ca 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -7,13 +7,13 @@ namespace ck_tile { -// Problem Description for BlockGemmARegBSmemCReg +// Problem Description for BlockGemm template -struct BlockGemmARegBSmemCRegProblem +struct BlockGemmProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index dfc63f04c6..5b4419b79f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -38,7 +41,7 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; -using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; @@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; -using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 71c59bbd17..fd5b004d36 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB } }; +template +struct WarpGemmAtrributeMfmaIterateK_SwizzleA +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + 
sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + + return c_vec; + } +}; + } // namespace ck_tile
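Editor's note: the IterateK attributes above widen the per-instruction K of a base MFMA (for example, a 32x32x8 MFMA iterated twice yields an effective 32x32x16 warp GEMM, which is presumably why the SwizzleB aliases were renamed from M16N16K32 to M32N32K16). The wrapper slices the wider A/B vectors into per-iteration chunks and issues the base instruction kKIter times while accumulating into the same C vector. A scalar sketch of that split-and-accumulate idea, with no MFMA intrinsics and illustrative sizes:

    // Scalar sketch of the IterateK accumulation; base_mfma stands in for Impl{}(c, a, b).
    #include <array>
    #include <cstdio>

    constexpr int kKBase = 8; // K handled by one base instruction (e.g. 32x32x8)
    constexpr int kKIter = 2; // iterations -> effective K = 16

    void base_mfma(float& c, const std::array<float, kKBase>& a,
                   const std::array<float, kKBase>& b)
    {
        for(int k = 0; k < kKBase; ++k) c += a[k] * b[k];
    }

    int main()
    {
        std::array<float, kKBase * kKIter> a{}, b{};
        a.fill(1.f); b.fill(2.f);

        float c = 0.f;
        for(int it = 0; it < kKIter; ++it) // static_for<0, kKIter, 1> in the device code
        {
            std::array<float, kKBase> a_chunk{}, b_chunk{};
            for(int k = 0; k < kKBase; ++k)
            {
                a_chunk[k] = a[it * kKBase + k]; // slice of the wider A vector
                b_chunk[k] = b[it * kKBase + k]; // slice of the wider B vector
            }
            base_mfma(c, a_chunk, b_chunk);
        }
        std::printf("c = %g (expect %g)\n", c, 2.f * kKBase * kKIter);
    }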