From e1a21655ae3e177b5503f7af65ef8d42df66aa12 Mon Sep 17 00:00:00 2001 From: danyao12 Date: Thu, 9 May 2024 17:08:08 +0800 Subject: [PATCH] FA bwd --- example/ck_tile/01_fmha/CMakeLists.txt | 33 +- example/ck_tile/01_fmha/fmha_bwd.cpp | 510 ++++--- example/ck_tile/01_fmha/fmha_bwd.hpp | 346 +++++ example/ck_tile/01_fmha/generate.py | 568 ++++++- .../ck_tile/01_fmha/script/benchmark_bwd.sh | 21 + .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 33 + include/ck_tile/core.hpp | 2 + .../core/arch/amd_buffer_addressing.hpp | 8 +- .../core/arch/generic_memory_space_atomic.hpp | 175 +++ include/ck_tile/core/tensor/buffer_view.hpp | 13 +- include/ck_tile/core/tensor/tensor_view.hpp | 30 +- include/ck_tile/core/tensor/tile_window.hpp | 60 + include/ck_tile/core/tensor/update_tile.hpp | 55 + include/ck_tile/ops/epilogue.hpp | 1 + .../ops/epilogue/custom_2d_epilogue.hpp | 41 + include/ck_tile/ops/fmha.hpp | 13 + .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 1331 ++++++++++++++++ .../fmha/kernel/fmha_bwd_tile_partitioner.hpp | 54 + .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 10 +- .../fmha/pipeline/block_fmha_bwd_dot_do_o.hpp | 95 ++ ...block_fmha_bwd_dot_do_o_default_policy.hpp | 20 + ...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp | 821 ++++++++++ ...k_dv_pipeline_ks_kts_vr_default_policy.hpp | 20 + ...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp | 794 ++++++++++ ...dq_dk_dv_pipeline_ks_vr_default_policy.hpp | 20 + ...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp | 665 ++++++++ ...v_pipeline_qs_ks_vr_dos_default_policy.hpp | 20 + ...block_fmha_bwd_pipeline_default_policy.hpp | 1343 +++++++++++++++++ .../pipeline/block_fmha_bwd_pipeline_enum.hpp | 16 + .../block_fmha_bwd_pipeline_problem.hpp | 91 ++ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 4 +- .../block/block_gemm_areg_bsmem_creg_v1.hpp | 159 +- .../block/block_gemm_asmem_breg_creg_v1.hpp | 228 +++ ..._gemm_asmem_breg_creg_v1_custom_policy.hpp | 36 + ...gemm_asmem_breg_creg_v1_default_policy.hpp | 56 + 
.../gemm/warp/warp_gemm_attribute_mfma.hpp | 12 +- 36 files changed, 7275 insertions(+), 429 deletions(-) create mode 100644 example/ck_tile/01_fmha/fmha_bwd.hpp create mode 100644 example/ck_tile/01_fmha/script/benchmark_bwd.sh create mode 100644 example/ck_tile/01_fmha/script/smoke_test_bwd.sh create mode 100644 include/ck_tile/core/arch/generic_memory_space_atomic.hpp create mode 100644 include/ck_tile/core/tensor/update_tile.hpp create mode 100644 include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp create mode 100644 
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e31c96caaa..027ce763e0 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,17 +1,29 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt + --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) -# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") @@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set(EXAMPLE_FMHA_BWD
"tile_example_fmha_bwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding tile_example ${EXAMPLE_FMHA_BWD}") +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) @@ -29,16 +49,21 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) +set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index ec52755f07..37a21b25ef 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++
b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#include "fmha_bwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + #include #include #include @@ -9,35 +14,28 @@ #include #include #include +#include -#include "ck/ck.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/tensor/tensor_view.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/common_header.hpp" +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" - -#include "common/arg_parser.hpp" -#include "fmha_bwd.hpp" -#include "mask.hpp" -#include "reference/reference_batched_elementwise.hpp" -#include "reference/reference_batched_gemm.hpp" -#include "reference/reference_batched_masking.hpp" -#include "reference/reference_batched_softmax.hpp" -#include "reference/reference_batched_dropout.hpp" -#include "utils.hpp" + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} auto create_args(int argc, char* argv[]) { - ArgParser arg_parser; + ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "weather do CPU validation or not") .insert("mode", "0", "kernel mode. 
0:batch, 1:group") .insert("b", "2", "batch size") @@ -69,11 +67,11 @@ auto create_args(int argc, char* argv[]) "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" "'xt:window_size', xformer style masking from top-left, window_size negative is " - "causal, possitive is swa\n" + "causal, positive is swa\n" "'xb:window_size', xformer style masking from bottom-r, window_size negative is " - "causal, possitive is swa\n" + "causal, positive is swa\n" "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " - "now)\n") + "now)") .insert("kname", "0", "if set to 1 will print kernel name") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") .insert("seed", @@ -96,18 +94,18 @@ auto get_elimit(int /*init_method*/) { double rtol = 1e-2; double atol = 1e-2; - return ck::make_tuple(rtol, atol); + return ck_tile::make_tuple(rtol, atol); } template -bool run(const ArgParser& arg_parser) +bool run(const ck_tile::ArgParser& arg_parser) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck::index_t batch = arg_parser.get_int("b"); - ck::index_t nhead = arg_parser.get_int("h"); - ck::index_t nhead_k = arg_parser.get_int("h_k"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); if(nhead_k == 0) nhead_k = nhead; @@ -117,12 +115,12 @@ bool run(const ArgParser& arg_parser) return false; } - ck::index_t seqlen_q = arg_parser.get_int("s"); - ck::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + 
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); if(seqlen_k == 0) seqlen_k = seqlen_q; - ck::index_t hdim_q = arg_parser.get_int("d"); - ck::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v == 0) hdim_v = hdim_q; if(hdim_q % 2 != 0 || hdim_v % 2 != 0) @@ -136,7 +134,7 @@ bool run(const ArgParser& arg_parser) float scale = arg_parser.get_float("scale"); if(scale == .0f) - scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); bool use_bias = arg_parser.get_bool("bias"); bool use_dbias = arg_parser.get_bool("dbias"); @@ -178,7 +176,7 @@ bool run(const ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - StreamConfig stream_config{ + ck_tile::stream_config stream_config{ nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); @@ -209,7 +207,7 @@ bool run(const ArgParser& arg_parser) auto max_seqlen_k = std::numeric_limits::min(); // we will use max seqlen to decide grid size { - for(ck::index_t wb = 0; wb < batch; ++wb) + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; @@ -224,11 +222,10 @@ bool run(const ArgParser& arg_parser) max_seqlen_k = real_seqlen_k; } - using namespace ck::literals; - - flop += nhead * - (3_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T - 2_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + flop += nhead * (static_cast(3) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + static_cast(2) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T num_byte += nhead * 
(sizeof(QDataType) * real_seqlen_q * hdim_q + sizeof(KDataType) * real_seqlen_k * hdim_q + @@ -243,85 +240,97 @@ bool run(const ArgParser& arg_parser) } auto get_lengths = [&](bool permute, - ck::index_t b /*batch*/, - ck::index_t h /*nhead*/, - ck::index_t s /*seqlen*/, - ck::index_t d /*hdim*/) { + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { if(permute) - return std::array{b, h, s, d}; + return std::array{b, h, s, d}; else - return std::array{b, s, h, d}; + return std::array{b, s, h, d}; }; // host memory for storing all the tensor elements - const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck::index_t shape_seqlen_q = + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); - const ck::index_t shape_seqlen_k = + const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); - Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - Tensor v_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); - // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + // use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host // will not be used for verification at all (but will be copied to device anyway). - Tensor bias_host( - use_bias ? 
get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - Tensor lse_host(std::array{batch, nhead, max_seqlen_q}); - Tensor d_host(std::array{batch, nhead, max_seqlen_q}); - Tensor randval_host( + ck_tile::HostTensor bias_host( + use_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor lse_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor d_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor randval_host( p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) - : std::array{1, 1, 1, 1}); - Tensor dq_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - Tensor dk_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); - Tensor dv_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); - Tensor do_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - Tensor dbias_host( - use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, shape_seqlen_k) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor dq_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor dk_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor dv_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor do_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor dbias_host( + use_dbias + ? 
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); if(init_method == 0) { - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); } else if(init_method == 1) { - ck::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - ck::utils::FillUniformDistribution{0.f, 1.f, seed}(do_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); } else if(init_method == 2) { - ck::utils::FillTrigValue{}(q_host); - ck::utils::FillTrigValue{}(k_host); - ck::utils::FillTrigValue{}(v_host); - ck::utils::FillTrigValue{}(bias_host); - ck::utils::FillTrigValue{}(do_host); + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(do_host); } - DeviceMem 
q_buf(q_host.GetElementSpaceSizeInBytes()); - DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); - DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); - DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); - DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); - DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes()); - DeviceMem d_buf(d_host.GetElementSpaceSizeInBytes()); - DeviceMem randval_buf(randval_host.GetElementSpaceSizeInBytes()); - DeviceMem dq_buf(dq_host.GetElementSpaceSizeInBytes()); - DeviceMem dk_buf(dk_host.GetElementSpaceSizeInBytes()); - DeviceMem dv_buf(dv_host.GetElementSpaceSizeInBytes()); - DeviceMem do_buf(do_host.GetElementSpaceSizeInBytes()); - DeviceMem dbias_buf(dbias_host.GetElementSpaceSizeInBytes()); - DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); 
q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -363,40 +372,39 @@ bool run(const ArgParser& arg_parser) /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. // setup stride_* arguments - const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); - const ck::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck::index_t stride_randval = (max_seqlen_k); - const ck::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); - const ck::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); - const ck::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); - const ck::index_t stride_dbias = (i_perm ? shape_seqlen_k : nhead * shape_seqlen_k); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck_tile::index_t stride_bias = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); // setup nhead_stride_* arguments - const ck::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); - const ck::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); - const ck::index_t nhead_stride_bias = - (i_perm ? 
0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); - const ck::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck::index_t nhead_stride_lsed = max_seqlen_q; - const ck::index_t nhead_stride_dbias = - (i_perm ? shape_seqlen_q * shape_seqlen_k : shape_seqlen_k); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_bias = 0; + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; + const ck_tile::index_t nhead_stride_dbias = + (i_perm ? 
shape_seqlen_q * max_seqlen_k : max_seqlen_k); // setup batch_stride_* arguments - const ck::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); - const ck::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); - const ck::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - const ck::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); - const ck::index_t batch_stride_lsed = (nhead * max_seqlen_q); - const ck::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); - const ck::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); - const ck::index_t batch_stride_dbias = (nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), @@ -456,7 +464,7 @@ bool run(const ArgParser& arg_parser) batch_stride_dbias, mask.left, mask.right, - static_cast(mask.type), + static_cast(mask.type), p_drop, p_undrop, 
s_randval, @@ -486,42 +494,43 @@ bool run(const ArgParser& arg_parser) bool pass = true; - std::vector> q_host_refs; - std::vector> k_host_refs; - std::vector> v_host_refs; - std::vector> o_host_refs; - std::vector> randval_host_refs; - std::vector> p_hp_host_refs; - std::vector> p_lp_host_refs; + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; randval_buf.FromDevice(randval_host.data()); - for(ck::index_t wb = 0; wb < batch; ++wb) + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode - const ck::index_t b = (mode == mode_enum::batch ? wb : 0); - const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); - Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k - Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k - Tensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n - Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o - Tensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m - Tensor randval_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n - Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n - Tensor p_hp_host_ref( + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision - Tensor p_dropped_hp_host_ref( + ck_tile::HostTensor p_dropped_hp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision - Tensor p_lp_host_ref( + ck_tile::HostTensor p_lp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision - ck::index_t nr = nhead / nhead_k; + ck_tile::index_t nr = nhead / nhead_k; // clang-format off // permute @@ -539,62 +548,68 @@ bool run(const ArgParser& arg_parser) // reference // S = scale * Q * K^T - reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, ck::identity{}, ck::identity{}, [&](AccDataType x) { - return scale * x; - }); // s_g_m_n = scale * q_g_m_k@k_g_n_k + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + 
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k if(use_bias) { // clang-format off - Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, // real_seqlen_k] - reference_batched_elementwise( - s_host_ref, bias_host_ref, s_host_ref); + ck_tile:: + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); } if(mask.type == mask_enum::no_mask) { - reference_batched_masking(s_host_ref, - FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); } else if(mask.type == mask_enum::window_generic) { - reference_batched_masking( - s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); } else { // if left window size is negative, means causal // else means generic (for current batch) if(mask.left < 0) - reference_batched_masking( + ck_tile::reference_batched_masking( s_host_ref, - ck::make_generic_attention_mask_from_lr_window( + ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); else - reference_batched_masking( + 
ck_tile::reference_batched_masking( s_host_ref, - ck::make_generic_attention_mask_from_lr_window( + ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); } - reference_batched_softmax( - s_host_ref, p_hp_host_ref, lse_host_ref); + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); if(p_drop > 0) { @@ -603,21 +618,21 @@ bool run(const ArgParser& arg_parser) randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); - reference_batched_dropout( + ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck::type_convert(self(idx)); + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); }); } else { p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck::type_convert(self(idx)); + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); }); } // O = P * V - reference_batched_gemm( + ck_tile::reference_batched_gemm( p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n // clang-format off @@ -652,28 +667,28 @@ bool run(const ArgParser& arg_parser) dv_buf.FromDevice(dv_host.data()); dbias_buf.FromDevice(dbias_host.data()); - for(ck::index_t wb = 0; wb < batch; ++wb) + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode - const ck::index_t b = (mode == mode_enum::batch ? 
wb : 0); - const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); - Tensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o - Tensor ds_hp_host_ref( + ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision - Tensor ds_lp_host_ref( + ck_tile::HostTensor ds_lp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision - Tensor dp_hp_host_ref( + ck_tile::HostTensor dp_hp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision - Tensor dbias_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n - Tensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k - Tensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k - Tensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o // clang-format off if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); @@ -683,12 +698,12 @@ bool run(const ArgParser& arg_parser) // dP = dO@V x Z w/ dropout // dP = dO@V w/o dropout auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o - reference_batched_gemm( + ck_tile::reference_batched_gemm( do_host_ref, v_t_host_ref, 
dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o if(p_drop > 0) { - reference_batched_dropout( + ck_tile::reference_batched_dropout( dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); } @@ -699,56 +714,59 @@ bool run(const ArgParser& arg_parser) { auto idx_gmo = idx_gmn; idx_gmo[2] = o; - do_dot_o += ck::type_convert(do_host_ref(idx_gmo)) * - ck::type_convert(o_host_refs[wb](idx_gmo)); + do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * + ck_tile::type_convert(o_host_refs[wb](idx_gmo)); } - self(idx_gmn) = ck::type_convert(p_hp_host_refs[wb](idx_gmn) * - (dp_hp_host_ref(idx_gmn) - do_dot_o)); + self(idx_gmn) = ck_tile::type_convert( + p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); }); if(use_dbias) { ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck::type_convert(self(idx)); + dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); }); } ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck::type_convert(self(idx)); + ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); }); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m - reference_batched_gemm( + ck_tile::reference_batched_gemm( p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m // dQ = scale * dS@K^T auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n - reference_batched_gemm( + ck_tile::reference_batched_gemm( ds_lp_host_ref, k_t_host_ref, dq_host_ref, - ck::identity{}, - ck::identity{}, - [&scale](const AccDataType& x) { return scale * x; }); // dq_g_m_k = ds_g_m_n@k_g_k_n + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n // dK = scale * dS^T@Q^T auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> 
ds_g_n_m auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m - reference_batched_gemm( + ck_tile::reference_batched_gemm( ds_t_lp_host_ref, q_t_host_ref, dk_host_ref, - ck::identity{}, - ck::identity{}, - [&scale](const AccDataType& x) { return scale * x; }); // dk_g_n_k = ds_g_n_m@q_g_k_m + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m - Tensor dq_host_result({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k - Tensor dk_host_result({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k - Tensor dv_host_result({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o - Tensor dbias_host_result( + ck_tile::HostTensor dq_host_result( + {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n // clang-format off @@ -764,36 +782,36 @@ bool run(const ArgParser& arg_parser) if(use_dbias) { - if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2] + key_offset); }); - else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2] + key_offset); }); + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); } // clang-format on auto [rtol, atol] = get_elimit(init_method); - bool dq_cur_pass = ck::utils::check_err(dq_host_result, - dq_host_ref, - std::string("Error: QGrad Incorrect results!"), - rtol, - atol); - bool dk_cur_pass = ck::utils::check_err(dk_host_result, - dk_host_ref, - std::string("Error: KGrad Incorrect results!"), - rtol, 
- atol); - bool dv_cur_pass = ck::utils::check_err(dv_host_result, - dv_host_ref, - std::string("Error: VGrad Incorrect results!"), - rtol, - atol); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); bool dbias_cur_pass = true; if(use_dbias) { - dbias_cur_pass = ck::utils::check_err(dbias_host_result, - dbias_host_ref, - std::string("Error: BiasGrad Incorrect results!"), - rtol, - atol); + dbias_cur_pass = ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); } pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) @@ -822,11 +840,11 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp new file mode 100644 index 0000000000..67817bbcd3 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; 
+ void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + 
args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_lsed, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 
grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr 
ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bool has_bias; + bool has_dbias; + bool has_dropout; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index f41d3d3fff..d00ddd8cfd 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -65,7 +65,6 @@ BOOL_MAP = { "f" : "false" } -DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -469,7 +468,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -507,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) if d == None: continue @@ -536,39 +535,574 @@ def 
get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw return (api_pool, gen) -def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + 
typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_epilogue_{F_idx} = + ck_tile::FmhaBwdEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" 
+FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_dot_do_o_(s, a); + r += fmha_bwd_dq_dk_dv_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str # true/false + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: 
+ return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' 
if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? +@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + 
def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}" + +@dataclass +class FmhaBwdDQDKDVKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BOOL_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def mask_name() -> str: + n = '' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + 
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + return n + # TODO: we don't encode idx here + mn = mask_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\ + f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ + f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" + if mn != '' : n += f'{mn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. 
+def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'bwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_kts_vr"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if (bias == "f" and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using 
fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}" +\ + 
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ + f"_o{self.F_occupancy}" + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) -def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: if 
output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_api(api_pool, output_dir) + if direction == 'fwd': + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + if direction == 'fwd': + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + 
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", description="gen api for CK fmha kernel", ) + parser.add_argument( + "-d", + "--direction", + default='fwd', + choices=['fwd', 'bwd'], + required=False, + help="choose the direction of kernels(default: fwd)" + ) parser.add_argument( "-o", "--output_dir", @@ -608,6 +1142,6 @@ if __name__ == "__main__": args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) + list_blobs(args.list_blobs, args.direction, args.filter, args.receipt, mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) + write_blobs(args.output_dir, args.direction, args.filter, args.receipt, mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh new file mode 100644 index 0000000000..7591f5442a --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done 
+done diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh new file mode 100644 index 0000000000..0a26260df0 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -0,0 +1,33 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do +for mode in 0 1 ; do +for bias in 0 1 ; do +for dbias in 0 1 ; do +for p_drop in 0.0 0.2; do + +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bdf8d79d34..5a175a61c8 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include 
"ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" @@ -37,6 +38,7 @@ #include "ck_tile/core/tensor/slice_tile.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/store_tile.hpp" +#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/sweep_tile.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 39d755f0d9..a123623f5b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -765,21 +765,21 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); // buffer store ui16 -__device__ void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); -__device__ void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); -__device__ void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, int32x4_t rsrc, index_t voffset, @@ -1658,7 +1658,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_th { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, diff --git 
a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp new file mode 100644 index 0000000000..6212db9169 --- /dev/null +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" + +namespace ck_tile { + +CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) +{ + return type_convert(type_convert(a) + type_convert(b)); +} + +CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) +{ + bf16x2_t rtn; + rtn[0] = add_bf16_t(a[0], b[0]); + rtn[1] = add_bf16_t(a[1], b[1]); + return rtn; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. 
+template +CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); + +template <> +CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) +{ + union U32BF162_ADDR + { + uint32_t* u32_a; + bf16x2_t* bf162_a; + }; + + union U32BF162 + { + uint32_t u32; + bf16x2_t bf162; + }; + + U32BF162_ADDR dword_addr; + U32BF162 cur_v; + U32BF162 new_; + uint32_t old_v, new_v; + dword_addr.bf162_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf162 = add_bf16x2_t(cur_v.bf162, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template +CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + return atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 2) + { + atomic_add(c_style_pointer_cast(p_dst), bit_cast(x)); + } + else if 
constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } +} + +template +CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicMax(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicMax(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 96b38241c0..ffe8f7a4fd 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" @@ -507,10 +508,10 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } @@ -518,7 +519,7 @@ struct buffer_view(c_style_pointer_cast(&p_data_[i]), x); + atomic_add_g, t_per_x>(&p_data_[i], x); } 
} } @@ -547,16 +548,16 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } else if(is_valid_element) { - atomic_max(c_style_pointer_cast(&p_data_[i]), x); + atomic_max_g, t_per_x>(&p_data_[i], x); } } diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index e37bd806de..656309532e 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -16,7 +16,9 @@ namespace ck_tile { -template +template struct tensor_view { using buffer_view = remove_reference_t; @@ -24,6 +26,7 @@ struct tensor_view using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + static constexpr auto DstInMemOp = DstInMemOp_; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -140,6 +143,23 @@ struct tensor_view x); } + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template update( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); @@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, } template (p, desc.get_element_space_size()); - return tensor_view{buffer_view, desc}; + return tensor_view{buffer_view, desc}; } template >{ - old_tensor_view.buf_, new_desc}; + return tensor_view, + remove_cvref_t::DstInMemOp>{old_tensor_view.buf_, new_desc}; } template + CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + 
using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // move thread's botom tensor coordiante // [x0', x1', ... 
] ==> [offset] // also move window-origin diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp new file mode 100644 index 0000000000..fbce7c4083 --- /dev/null +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.update(dstr_tensor); +} + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.update(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 388f52c898..2b99ca162b 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -4,4 +4,5 @@ #pragma once #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/custom_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git 
a/include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp new file mode 100644 index 0000000000..aeff6237a2 --- /dev/null +++ b/include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaBwdEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; +}; + +template +struct FmhaBwdEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE auto operator()(KGradDramWindowTmp& dk_dram_window_tmp, + VGradDramWindowTmp& dv_dram_window_tmp, + const KGradAccTile& dk_acc_tile, + const VGradAccTile& dv_acc_tile) + { + store_tile(dk_dram_window_tmp, cast_tile(dk_acc_tile)); + store_tile(dv_dram_window_tmp, cast_tile(dv_acc_tile)); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 9d08a55bf6..0578f8abe2 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -17,6 +17,19 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" +#include 
"ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp new file mode 100644 index 0000000000..04dbe2a4ef --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -0,0 +1,1331 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q] +// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v] +// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v]) +// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q]) +// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k] +// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1] +// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1] + +namespace ck_tile { + +template +struct FmhaBwdDQDKDVKernel /* Backward FMHA kernel: each block handles one kN0-wide key tile of one (head, batch); it runs FmhaPipeline (which is handed the dQ and dBias windows) and passes the returned dK/dV accumulator tiles to EpiloguePipeline. */ +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using GemmDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + using QGradDataType = ck_tile::remove_cvref_t; + using KGradDataType = ck_tile::remove_cvref_t; + using VGradDataType = ck_tile::remove_cvref_t; + using BiasGradDataType = ck_tile::remove_cvref_t; + +
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() /* Returns the kernel name string assembled from QK head dim, data-type tag, group/batch mode, block tile sizes, Gemm0 block-warp and warp-tile shape, kBlockPerCu, pipeline name, and padding/bias/dbias/mask/dropout suffixes. */ + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (kHasBias ?
"_bias" : "") + + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid the duplicated base class problem, introduce a template + // arg + struct FmhaBwdEmptyKargs + { + }; + + // kargs use aggregate initialization, so no constructor is provided + // use inheritance to minimize karg size + // users need to call MakeKargs() to create kargs. + struct FmhaBwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lse_ptr; + const void* do_ptr; + const void* d_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_head_q; /* forwarded to BlockDropout in operator() */ + ck_tile::index_t nhead_ratio_qk; /* the nhead_q / nhead_k ratio described in the comment above */ + float raw_scale; +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale; +#endif + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; /* used for both the LSE and the D pointer */ + + ck_tile::index_t batch_stride_lsed; /* used for both the LSE and the D pointer */ + }; + + struct FmhaBwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaBwdCommonBiasGradKargs + { + void* dbias_ptr = nullptr; + ck_tile::index_t stride_dbias = 0; + ck_tile::index_t nhead_stride_dbias = 0; + }; + + struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs + { +
ck_tile::index_t batch_stride_dbias = 0; + }; + + struct FmhaBwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaBwdCommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset, + const float raw_scale) /* Sets rp_undrop = 1 / (1 - p_drop), scale_rp_undrop = rp_undrop * raw_scale, the uint8 threshold floor((1 - p_drop) * uint8 max), and unpacks (seed, offset). NOTE(review): p_drop == 1 divides by zero here — confirm callers exclude it. */ + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaBwdBatchModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + }; + + struct FmhaBwdGroupModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; /* may be nullptr; operator() then derives seqlen_k from seqstart_k_ptr */ + }; + + using Kargs = std::conditional_t; + /* Host-side factory that fills the per-batch batch_stride_* fields; the bias / dbias / mask / dropout members are written only when the matching kHas* flag is set. */ + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void*
dk_ptr, + void* dv_ptr, + void* dbias_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dk, + ck_tile::index_t batch_stride_dv, + ck_tile::index_t batch_stride_dbias, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + batch_stride_q, +
batch_stride_k, + batch_stride_v, + batch_stride_do, + batch_stride_dk, + batch_stride_dv}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + kargs.batch_stride_dbias = batch_stride_dbias; + } + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + /* Host-side factory that takes seqstart/seqlen pointers instead of sequence lengths and batch strides; seqlen_q / seqlen_k are stored as -1 and filled in per batch inside operator(). */ + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, +
ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + return
TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const /* Device entry point. kargs is taken by value because the group-mode branch overwrites kargs.seqlen_q / kargs.seqlen_k for the current batch. */ + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_lsed = 0; + long_index_t batch_offset_dk = 0; + long_index_t batch_offset_dv = 0; + long_index_t batch_offset_dbias = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + else + { + batch_offset_bias = key_start; /* NOTE(review): batch_offset_bias is only read under kHasBias further down, so this value looks unused — confirm intent */ + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = query_start * kargs.stride_dbias; + } + else + { + batch_offset_dbias = key_start; /* NOTE(review): batch_offset_dbias is only read under kHasBiasGrad further down, so this value looks unused — confirm intent */ + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto
adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_k <= i_n0) + { + return; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = static_cast(i_batch) * kargs.batch_stride_dbias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + const LSEDataType* lse_ptr = reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const DDataType* d_ptr =
reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; /* dQ is addressed with Q's head stride and batch offset (and Q's row stride in dq_dram) */ + KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_dk; + VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_dv; /* NOTE(review): dK/dV are indexed by the query head i_nhead but with K/V head strides, while K/V themselves use i_nhead / nhead_ratio_qk — confirm the dK/dV buffer layout when nhead_ratio_qk > 1 */ + + // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto qt_dram_naive = + transform_tensor_view(q_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto qt_dram = [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); +
} + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto kt_dram_naive = + transform_tensor_view(k_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto kt_dram = [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kVLoadOnce) + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + lse_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + const auto do_dram = [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{},
number{}), + sequence{}); + } + }(); + + const auto dot_dram_naive = + transform_tensor_view(do_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto dot_dram = [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto dq_dram = [&]() { + const auto dq_dram_naive = make_naive_tensor_view( + dq_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto qt_dram_window = + make_tile_window(qt_dram, + [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto kt_dram_window = + make_tile_window(kt_dram, + [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, i_n0}); + + auto v_dram_window = make_tile_window( + v_dram, + [&]() { + if constexpr(FmhaPipeline::kVLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto do_dram_window =
make_tile_window( + do_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto dot_dram_window = + make_tile_window(dot_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto dq_dram_window = make_tile_window( + dq_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto lse_dram_window = + make_tile_window(lse_dram, make_tuple(number{}), {0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dbias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(kHasBiasGrad) + { + BiasGradDataType* dbias_ptr = + reinterpret_cast(kargs.dbias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dbias + + batch_offset_dbias; + + auto dbias_dram = [&]() { + const auto dbias_dram_naive = + make_naive_tensor_view( + dbias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), +
make_tuple(kargs.stride_dbias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(dbias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // dropout + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + scale_rp_undrop = kargs.scale_rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, +
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + qt_dram_window, + k_dram_window, + kt_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + dot_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dbias_dram_window, + mask, + kargs.raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + kargs.scale, +#endif + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); + + auto dk_dram = [&]() { + const auto dk_dram_naive = make_naive_tensor_view( + dk_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_dk, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dk_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dv_dram = [&]() { + const auto dv_dram_naive = make_naive_tensor_view( + dv_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_dv, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dv_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dk_dram_window = make_tile_window( + dk_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + auto dv_dram_window = make_tile_window( + dv_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + EpiloguePipeline{}(dk_dram_window, dv_dram_window, dk_acc_tile, dv_acc_tile); + } +}; + +template +struct FmhaBwdOGradDotOKernel /* Kernel wrapper that builds O, dO and D windows for one tile of kM0 (= kBlockSize) query rows and invokes FmhaBwdOGradDotO on them together with p_undrop; presumably this produces the D = rowsum(dO * O) term — confirm against FmhaBwdOGradDotO. */ +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; + static constexpr ck_tile::index_t kM0 = kBlockSize; + static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim; + + using DDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using
OGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV; + + // kargs use aggregate initialization, so no constructor is provided + // use inheritance to minimize karg size + // users need to call MakeKargs() to create kargs. + struct FmhaBwdOGradDotOCommonKargs + { + const void* o_ptr; + const void* do_ptr; + void* d_ptr; + + float p_undrop; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t stride_do; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_d; + ck_tile::index_t batch_stride_d; + }; + + struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs + { + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + seqlen_q, + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + batch_stride_do, + batch_stride_o}; + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + const void*
seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + -1, // seqlen will be updated by another pointer + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + reinterpret_cast(seqstart_q_ptr)}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + + long_index_t batch_offset_o = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_d = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.stride_o; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + + // get real # queries under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_do =
static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + } + + // for simplicity, batch stride we just modify the pointer + const ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_d + + batch_offset_d; + + // O/dO/D DRAM and DRAM window + const auto o_dram = [&]() { + auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto do_dram = [&]() { + auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + return pad_tensor_view(do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto do_dram_window = + make_tile_window(do_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {i_m0}); + + FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp new file mode 100644 index 0000000000..bc875b8e5a --- /dev/null +++
b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaBwdTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + // TODO: this may need tuning; grid is (ceil(seqlen_k / kN0), nhead, batch) + return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +template +struct FmhaBwdOGradDotOTilePartitioner +{ + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + // TODO: this may need tuning; grid is (ceil(seqlen_q / kBlockSize), nhead, batch) + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index 52f458c72e..214ecb7b7c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,10 +18,10 @@ struct FmhaFwdTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp new file mode 100644 index 0000000000..f189937038 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdOGradDotO +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kVHeaddim = Problem::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 
1 : Policy::template GetAlignmentOGrad(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + template + CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + DDramBlockWindowTmp& d_dram_block_window_tmp, + float p_undrop) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == + OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + auto o_dram_window = + make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(), + o_dram_block_window_tmp.get_window_lengths(), + o_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreODramTileDistribution()); + + auto o = load_tile(o_dram_window); + + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + do_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreOGradDramTileDistribution()); + + auto do_ = load_tile(do_dram_window); + + // declare d + constexpr auto d_dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); + + auto d = make_static_distributed_tensor(d_dstr); + + clear_tile(d); // Initialize D + + constexpr auto o_spans = decltype(o)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + d(i_idx) += + (type_convert(o[i_j_idx]) * type_convert(do_[i_j_idx])); + }); + }); + + tile_elementwise_inout([&p_undrop](auto& x) { x = x 
* p_undrop; }, d); + + store_tile(d_dram_block_window_tmp, d); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp new file mode 100644 index 0000000000..7843ab33a1 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// These templates are not used here. +using BlockFmhaBwdOGradDotODefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp new file mode 100644 index 0000000000..4b2c469ca9 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -0,0 +1,821 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSKTSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = true; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = 
Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view (and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should be able to override this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_kts_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& kt_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == 
OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto kt_lds = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto do_lds = make_tensor_view( + do_lds_ptr, 
Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + 
v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc & dv_acc are both cleared, return them + // Note: v loaded but no fence, ignore it. 
+ return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // persistent K in LDS + + auto kt_dram_block_window = kt_dram_block_window_tmp; + + auto kt_dram_window = make_tile_window( + kt_dram_block_window.get_bottom_tensor_view(), + kt_dram_block_window.get_window_lengths(), + kt_dram_block_window.get_window_origin(), + Policy::template MakeKTDramTileDistribution()); // K^T DRAM tile window for + // load + + auto kt_block_tile = load_tile(kt_dram_window); + + auto kt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledKTRegBlockDescriptor()); + shuffle_tile(kt_shuffle_tmp, kt_block_tile); + + store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + 
{seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template 
MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i 
+ 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(kHasBias) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_lse == 
-numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = 
load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, 
kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp new file mode 100644 index 0000000000..a05fbf252f --- /dev/null +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline keeps V in registers and K & K^T in LDS. +using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp new file mode 100644 index 0000000000..ce81b3bfd6 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -0,0 +1,794 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSVR // FMHA backward block pipeline "ks_vr": K is stored once to LDS, V is loaded once into a register tile; operator() returns (dk_acc, dv_acc) and writes dQ through dq_dram_block_window +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t 
kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view (and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should be able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 
1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == 
QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS (same LDS offset as the Q tile above) + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS (same smem_ptr base as the K tile, viewed through MakeKLdsBlockDescriptorAsKT) + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template 
MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, 
number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template 
GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc & dv_acc are all cleared, return them + // Note: v loaded but no fence, ignore it. 
+ return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto 
qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template 
MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 (accumulates S^T into st_acc over k0_loops slices of the K tile in LDS) + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if 
constexpr(kHasBias) + { + block_sync_lds(); + auto bias_shuffle_tmp = 
make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 (accumulates into dv_acc) + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 (accumulates PGrad^T into dpt_acc using the V register tile) + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + 
+ // STAGE 5, P^T(PGrad^T - D) + const auto d = 
load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 (accumulates into dk_acc) + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + 
move_tile_window(qt_dram_window, {0, 
kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 (accumulates into dq_acc, reading SGrad and K^T from LDS) + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp new file mode 100644 index 0000000000..cc4e6304d0 --- /dev/null 
+++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline keeps V in registers and K in LDS. +using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp new file mode 100644 index 0000000000..5ffa7f8d50 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + 
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = true; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = true; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // Last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length) + // ... together with the tensor distribution. The tensor distribution should be able to override this. + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 
1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); /* NOTE(review): the padded fallback is 2 here but 1 for every sibling alignment; dQ is the only output this pipeline writes through update_tile() -- presumably related, confirm the 2 is intentional */ + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "qs_ks_vr_dos"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + /* Processes one K/V tile: loops over the Q tiles in the row range returned by mask.GetTileRangeAlongY(), accumulates dQ into dq_dram_block_window_tmp through update_tile() once per Q tile, and returns the (dk_acc, dv_acc) accumulators. The qt/kt/dot window parameters are accepted but unused (their names are commented out). */ + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == 
QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + auto qt_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template 
MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + auto dot_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + 
v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad accumulators (zeroed below, accumulated across all Q tiles) + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); // number of kM0-row Q tiles to iterate + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: dk_acc & dv_acc are still all-zero here (cleared above), return them as-is + // Note: v loaded but no fence, ignore it. 
+ return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); // re-origin at the first Q row returned by the mask + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // origin (M, N): M restarts at seqlen_q_start, N keeps the caller's column + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // origin (M, N): M restarts at seqlen_q_start, N keeps the caller's column + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + 
d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + clear_tile(st_acc); // Initialize S^T + store_tile(q_lds_window, q_block_tile); // LDS write + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order 
of global loads + } + + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + }); + } + + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + block_sync_lds(); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(kHasBias) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto 
lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + block_sync_lds(); + store_tile(do_lds_window, do_block_tile); // store the prefetch + + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence{}, sequence<(i_k1 + 1) * kK1, kN0>{}), + get_slice_tile(dot_lds_window, + sequence<0, i_k1 * kK1>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + clear_tile(dpt_acc); // Initialize PGrad^T + + static_for<0, k2_loops, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + get_slice_tile(do_lds_window, + sequence<0, i_k2 * kK2>{}, + sequence{}), + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + }); + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? 
(dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + block_sync_lds(); + const auto dst_gemm = cast_tile(dst); + + static_for<0, k3_loops, 1>{}([&](auto i_k3) { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence{}, sequence<(i_k3 + 1) * kK3, kN0>{}), + get_slice_tile(qt_lds_window, + sequence<0, i_k3 * kK3>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + 
move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale: scale_rp_undrop when dropout is enabled, raw_scale otherwise + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale: multiplied by rp_undrop only when dropout is enabled + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp new file mode 100644 index 0000000000..ac81990e07 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// Default policy for the pipeline that keeps V in registers and Q, K and dO in LDS. +using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..ba840e725b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -0,0 +1,1343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineDefaultPolicy +{ + static constexpr bool QLoadOnce = + QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once + static constexpr bool QTLoadOnce = + QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once + static constexpr bool KLoadOnce = + KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once + static constexpr bool KTLoadOnce = + KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once + static constexpr bool VLoadOnce = + VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once + static constexpr bool OGradLoadOnce = + OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once + static constexpr bool OGradTLoadOnce = + OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + 
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(VLoadOnce) + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + else + { + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() + { + constexpr index_t kBlockSize = 
Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + // TODO: not correct! 
+ if constexpr(total_pixels > 32) + return 8; + else + return 4; + } + + // these are for lds + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + { + // TODO: this is for 3d layout + using BiasDataType = remove_cvref_t; + return 16 / sizeof(BiasDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() + { + // TODO: this is for 3d layout + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() + { + // TODO: this is for 3d layout + using GemmDataType = remove_cvref_t; + return 16 / sizeof(GemmDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = 
detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + // 3d + padding: the raw LDS layout is 3-D (KPerBlock / KPack, MNPerBlock, KPack) with a padded outermost stride of (MNPerBlock + 1) * KPack; the returned view is 2-D [MNPerBlock, KPerBlock] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return x_lds_block_desc; + } + + // 3d + padding: same raw LDS layout as MakeXLdsBlockDescriptor(), but the output dims are swapped so the returned view is the transposed 2-D [KPerBlock, MNPerBlock] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + static_assert(PixelsPerRow % KPack == 0); + constexpr index_t NPerRow = PixelsPerRow / KPack; + static_assert(MNPerBlock % NPerRow == 0); + static_assert(KPerBlock % KPack == 0); + + constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + 
KPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + xt_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = 
GetSmemKPackK(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPack = GetSmemKPackV(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPack = GetSmemKPackSGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType); + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + 
else + return Problem::BlockFmhaShape::kK3; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() + { + using KDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType); + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() + { + using QGradDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() + { + using BiasDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType); + constexpr index_t kKPack = GetSmemKPackBias(); + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kMPerBlock % kKPack == 0); + + constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + 
number{}, + number{}, + number{}), + make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto biast_lds_block_desc = transform_tensor_descriptor( + biast_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return biast_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(); + return smem_size_q; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() + { + constexpr index_t smem_size_qt = [&]() { + if constexpr(QLoadOnce && !QTLoadOnce) + return 0; + else + return sizeof(typename Problem::QDataType) * + MakeQTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_qt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * + MakeKLdsBlockDescriptor().get_element_space_size(); + return smem_size_k; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() + { + constexpr index_t smem_size_kt = [&]() { + if constexpr(KLoadOnce && !KTLoadOnce) + return 0; + else + return sizeof(typename Problem::KDataType) * + MakeKTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_kt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + constexpr index_t smem_size_v = [&]() { + if constexpr(VLoadOnce) + return 0; + else + return sizeof(typename Problem::VDataType) * + MakeVLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_v; + } + + template + 
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() + { + constexpr index_t smem_size_do = + sizeof(typename Problem::OGradDataType) * + MakeOGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_do; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() + { + constexpr index_t smem_size_dot = [&]() { + if constexpr(OGradLoadOnce && !OGradTLoadOnce) + return 0; + else + return sizeof(typename Problem::OGradDataType) * + MakeOGradTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_dot; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() + { + constexpr index_t smem_size_ds = + sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_ds; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() + { + constexpr index_t smem_size_bias = [&]() { + if constexpr(Problem::kHasBias) + return sizeof(typename Problem::BiasDataType) * + MakeBiasTLdsBlockDescriptor().get_element_space_size(); + else + return 0; + }(); + return smem_size_bias; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_qt = GetSmemSizeQT(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_kt = GetSmemSizeKT(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_dot = GetSmemSizeOGradT(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias); + + index_t smem_size = 0; + + if constexpr(QLoadOnce && OGradLoadOnce) + smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot + + smem_size_transpose; // 1~4 & 
10 + else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += smem_size_q + smem_size_qt + + max(smem_size_do, + smem_size_dot, + smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce) + smem_size += smem_size_do + smem_size_dot + + max(smem_size_q, + smem_size_qt, + smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += max(smem_size_q, + smem_size_qt, + smem_size_do, + smem_size_dot, + smem_size_transpose); // 9/13 TODO: Multiple buffers strategy + + // 14/15 needs to be adjusted + if constexpr(KLoadOnce) + smem_size += (smem_size_k + smem_size_kt); // 1~13 + else + smem_size = + max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy + + return max(smem_size, smem_size_v); // 15 + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + + 
constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // 
coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + + constexpr index_t K1 = GetAlignmentOGrad(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto 
MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + 
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + static_assert(kMPerBlock == M0 * M1 * M2 * M3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<1, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return 
WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBRegCRegV1CustomPolicy; + + return BlockGemmASmemBRegCRegV1{}; + } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + // { + // using BlockGemmProblem = + // BlockGemmPipelineProblem>; + // constexpr auto warp_gemm = []() { + // if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + // } + // else if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + // } + // }(); + + // using BlockGemmPolicy = + // BlockGemmASmemBSmemCRegV1CustomPolicy; + + // return BlockGemmASmemBSmemCRegV1{}; + // } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp new file mode 100644 index 
0000000000..a54a9fcb32 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaBwdPipelineEnum +{ + KSKTSVR = 0, + QSKSVROGradS, + KSVR, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp new file mode 100644 index 0000000000..5ed41d6264 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool 
kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasBias = Traits::kHasBias; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +template +struct BlockFmhaBwdOGradDotOPipelineProblem +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kVHeaddim = kVHeaddim_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 7b2940bd6b..12af81bb98 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -703,7 +703,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - __host__ __device__ static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { if constexpr(AsyncCopyK) { @@ -716,7 +716,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - __host__ __device__ static constexpr ck_tile::index_t GetSmemSizeDropout() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() { if 
constexpr(Problem::kHasDropout) { diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 80dda9f17d..84883d6ed8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" namespace ck_tile { @@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1 }); } - // C = A * B - template - CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const { - static_assert( - std::is_same_v> && - std::is_same_v>, 
- "wrong!"); - - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - const index_t iNWarp = get_warp_id() % NWarp; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } - // constrcut from A-block-tensor from A-Block-tensor-tmp - 
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent - // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); - - a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // Construct C-Block-HostTensor - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - using AWarpDstr = typename WG::AWarpDstr; - using CWarpDstr = typename WG::CWarpDstr; - - using AWarpTensor = typename WG::AWarpTensor; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto 
kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp new file mode 100644 index 0000000000..65ce1a9b8f --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmASmemBRegCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = 
get_warp_id() / NWarp; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode); + + // constrcut from B-block-tensor from B-Block-tensor-tmp + // FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent + // distribution + auto b_block_tensor = + make_static_distributed_tensor(b_block_dstr); + + b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer(); + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + 
move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = 
BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..5a17578f69 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmASmemBRegCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..cd16f09c37 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBRegCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBRegCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index f2e586f794..ec78bcf4e8 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -526,9 +526,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - reinterpret_cast(a_vec) + reinterpret_cast(a_vec) .template get_as()[iKIter], - reinterpret_cast(b_vec) + reinterpret_cast(b_vec) .template get_as()[iKIter]); }); } 
@@ -541,14 +541,14 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA using buf_b = thread_buffer; auto c_vec = Impl{}( - reinterpret_cast(a_vec).template get_as()[I0], - reinterpret_cast(b_vec).template get_as()[I0]); + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - reinterpret_cast(a_vec) + reinterpret_cast(a_vec) .template get_as()[iKIter], - reinterpret_cast(b_vec) + reinterpret_cast(b_vec) .template get_as()[iKIter]); });