This commit is contained in:
danyao12
2024-05-09 17:08:08 +08:00
parent bbd2e1eae3
commit e1a21655ae
36 changed files with 7275 additions and 429 deletions

View File

@@ -1,17 +1,29 @@
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory
# List the backward-direction kernel blobs at configure time (the generator is
# invoked with --list_blobs, so it only writes bwd_blob_list.txt here).
# The exit status is checked so that a generator failure stops the configure
# step right away instead of surfacing later when bwd_blob_list.txt is read.
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
  --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
  RESULT_VARIABLE FMHA_BWD_LIST_BLOBS_RESULT
)
# RESULT_VARIABLE holds the exit code, or an error string if the process could
# not be started; anything other than 0 is treated as a failure.
if(NOT FMHA_BWD_LIST_BLOBS_RESULT EQUAL 0)
  message(FATAL_ERROR "generate.py failed to list FMHA bwd kernels (result: ${FMHA_BWD_LIST_BLOBS_RESULT})")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as the current cmake list, otherwise cmake will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
# Build-time rule for the backward-direction blobs: runs generate.py with
# --output_dir pointing at the current binary directory. The OUTPUT list is
# FMHA_BWD_GEN_BLOBS, so a target that lists those files as sources depends on
# this command.
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
# Name of the backward FMHA example executable target.
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
# Report the target actually being added: the name is held in EXAMPLE_FMHA_BWD
# (EXAMPLE_NAME is not set anywhere in this list, so it would print empty).
message("adding tile_example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# The generated backward kernel blobs are compiled into the example itself.
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous since it makes the whole kernel flush denormals to zero
# WIP with the compiler team on an exp2 intrinsic; remove this once it is available
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
@@ -29,16 +49,21 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2)
endif()
# Flags that depend on the fast-exp2 switch. Both examples are compiled with the
# CK_TILE_FMHA_FWD_FAST_EXP2 define; the fast path additionally passes
# -fgpu-flush-denormals-to-zero.
if(FMHA_FWD_FAST_EXP2)
  set(FMHA_EXP2_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
  set(FMHA_EXP2_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()

# -Wno-undefined-func-template: the function specializations are auto-generated,
#   so the sources must compile without explicit declarations of them.
# -Wno-float-equal: floating-point values are compared directly in order to
#   check sentinel values.
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS
  -Wno-undefined-func-template ${FMHA_EXP2_COMPILE_OPTIONS} -Wno-float-equal)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS
  -Wno-undefined-func-template ${FMHA_EXP2_COMPILE_OPTIONS} -Wno-float-equal)

target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})

View File

@@ -1,6 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include <array>
#include <cstring>
#include <functional>
@@ -9,35 +14,28 @@
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/common_header.hpp"
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "common/arg_parser.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
#include "reference/reference_batched_elementwise.hpp"
#include "reference/reference_batched_gemm.hpp"
#include "reference/reference_batched_masking.hpp"
#include "reference/reference_batched_softmax.hpp"
#include "reference/reference_batched_dropout.hpp"
#include "utils.hpp"
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
auto create_args(int argc, char* argv[])
{
ArgParser arg_parser;
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
@@ -69,11 +67,11 @@ auto create_args(int argc, char* argv[])
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, possitive is swa\n"
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, possitive is swa\n"
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)\n")
"now)")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
.insert("seed",
@@ -96,18 +94,18 @@ auto get_elimit(int /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck::make_tuple(rtol, atol);
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType>
bool run(const ArgParser& arg_parser)
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck::index_t batch = arg_parser.get_int("b");
ck::index_t nhead = arg_parser.get_int("h");
ck::index_t nhead_k = arg_parser.get_int("h_k");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k == 0)
nhead_k = nhead;
@@ -117,12 +115,12 @@ bool run(const ArgParser& arg_parser)
return false;
}
ck::index_t seqlen_q = arg_parser.get_int("s");
ck::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
if(seqlen_k == 0)
seqlen_k = seqlen_q;
ck::index_t hdim_q = arg_parser.get_int("d");
ck::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v == 0)
hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
@@ -136,7 +134,7 @@ bool run(const ArgParser& arg_parser)
float scale = arg_parser.get_float("scale");
if(scale == .0f)
scale = 1.0 / ck::math::sqrt(static_cast<float>(hdim_q));
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
bool use_bias = arg_parser.get_bool("bias");
bool use_dbias = arg_parser.get_bool("dbias");
@@ -178,7 +176,7 @@ bool run(const ArgParser& arg_parser)
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
StreamConfig stream_config{
ck_tile::stream_config stream_config{
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
@@ -209,7 +207,7 @@ bool run(const ArgParser& arg_parser)
auto max_seqlen_k =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
{
for(ck::index_t wb = 0; wb < batch; ++wb)
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
@@ -224,11 +222,10 @@ bool run(const ArgParser& arg_parser)
max_seqlen_k = real_seqlen_k;
}
using namespace ck::literals;
flop += nhead *
(3_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
2_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(KDataType) * real_seqlen_k * hdim_q +
@@ -243,85 +240,97 @@ bool run(const ArgParser& arg_parser)
}
auto get_lengths = [&](bool permute,
ck::index_t b /*batch*/,
ck::index_t h /*nhead*/,
ck::index_t s /*seqlen*/,
ck::index_t d /*hdim*/) {
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) {
if(permute)
return std::array<ck::index_t, 4>{b, h, s, d};
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck::index_t, 4>{b, s, h, d};
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
// host memory for storing all the tensor elements
const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck::index_t shape_seqlen_q =
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck::index_t shape_seqlen_k =
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
Tensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
Tensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
Tensor<VDataType> v_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VDataType> v_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
// use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
Tensor<BiasDataType> bias_host(
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
Tensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
Tensor<LSEDataType> lse_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
Tensor<DDataType> d_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
Tensor<RandValOutputDataType> randval_host(
ck_tile::HostTensor<BiasDataType> bias_host(
use_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck::index_t, 4>{1, 1, 1, 1});
Tensor<QGradDataType> dq_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
Tensor<KGradDataType> dk_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
Tensor<VGradDataType> dv_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
Tensor<OGradDataType> do_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
Tensor<BiasGradDataType> dbias_host(
use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, shape_seqlen_k)
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<QGradDataType> dq_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KGradDataType> dk_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VGradDataType> dv_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
ck_tile::HostTensor<OGradDataType> do_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<BiasGradDataType> dbias_host(
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
if(init_method == 0)
{
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
ck::utils::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
}
else if(init_method == 1)
{
ck::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
ck::utils::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
}
else if(init_method == 2)
{
ck::utils::FillTrigValue<QDataType>{}(q_host);
ck::utils::FillTrigValue<KDataType>{}(k_host);
ck::utils::FillTrigValue<VDataType>{}(v_host);
ck::utils::FillTrigValue<BiasDataType>{}(bias_host);
ck::utils::FillTrigValue<OGradDataType>{}(do_host);
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
}
DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes());
DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes());
DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes());
DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes());
DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes());
DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes());
DeviceMem d_buf(d_host.GetElementSpaceSizeInBytes());
DeviceMem randval_buf(randval_host.GetElementSpaceSizeInBytes());
DeviceMem dq_buf(dq_host.GetElementSpaceSizeInBytes());
DeviceMem dk_buf(dk_host.GetElementSpaceSizeInBytes());
DeviceMem dv_buf(dv_host.GetElementSpaceSizeInBytes());
DeviceMem do_buf(do_host.GetElementSpaceSizeInBytes());
DeviceMem dbias_buf(dbias_host.GetElementSpaceSizeInBytes());
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -363,40 +372,39 @@ bool run(const ArgParser& arg_parser)
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
const ck::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck::index_t stride_randval = (max_seqlen_k);
const ck::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
const ck::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck::index_t stride_dbias = (i_perm ? shape_seqlen_k : nhead * shape_seqlen_k);
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
const ck_tile::index_t stride_bias = (max_seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
// setup nhead_stride_* arguments
const ck::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
const ck::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck::index_t nhead_stride_lsed = max_seqlen_q;
const ck::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * shape_seqlen_k : shape_seqlen_k);
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_bias = 0;
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
const ck::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
const ck::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck::index_t batch_stride_dbias = (nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_bias = 0;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
@@ -456,7 +464,7 @@ bool run(const ArgParser& arg_parser)
batch_stride_dbias,
mask.left,
mask.right,
static_cast<ck::index_t>(mask.type),
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
s_randval,
@@ -486,42 +494,43 @@ bool run(const ArgParser& arg_parser)
bool pass = true;
std::vector<Tensor<QDataType>> q_host_refs;
std::vector<Tensor<KDataType>> k_host_refs;
std::vector<Tensor<VDataType>> v_host_refs;
std::vector<Tensor<ODataType>> o_host_refs;
std::vector<Tensor<RandValOutputDataType>> randval_host_refs;
std::vector<Tensor<AccDataType>> p_hp_host_refs;
std::vector<Tensor<GemmDataType>> p_lp_host_refs;
std::vector<ck_tile::HostTensor<QDataType>> q_host_refs;
std::vector<ck_tile::HostTensor<KDataType>> k_host_refs;
std::vector<ck_tile::HostTensor<VDataType>> v_host_refs;
std::vector<ck_tile::HostTensor<ODataType>> o_host_refs;
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
randval_buf.FromDevice(randval_host.data());
for(ck::index_t wb = 0; wb < batch; ++wb)
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
Tensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
Tensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
Tensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
Tensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
Tensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
Tensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
Tensor<AccDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
Tensor<AccDataType> p_hp_host_ref(
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
ck_tile::HostTensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
ck_tile::HostTensor<AccDataType> s_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
ck_tile::HostTensor<AccDataType> p_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
Tensor<AccDataType> p_dropped_hp_host_ref(
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
Tensor<GemmDataType> p_lp_host_ref(
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
ck::index_t nr = nhead / nhead_k;
ck_tile::index_t nr = nhead / nhead_k;
// clang-format off
// permute
@@ -539,62 +548,68 @@ bool run(const ArgParser& arg_parser)
// reference
// S = scale * Q * K^T
reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, ck::identity{}, ck::identity{}, [&](AccDataType x) {
return scale * x;
}); // s_g_m_n = scale * q_g_m_k@k_g_n_k
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
if(use_bias)
{
// clang-format off
Tensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
// real_seqlen_k]
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
s_host_ref, bias_host_ref, s_host_ref);
ck_tile::
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
reference_batched_masking<AccDataType>(s_host_ref,
FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
}
else if(mask.type == mask_enum::window_generic)
{
reference_batched_masking<AccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
reference_batched_masking<AccDataType>(
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
reference_batched_masking<AccDataType>(
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, lse_host_ref);
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
if(p_drop > 0)
{
@@ -603,21 +618,21 @@ bool run(const ArgParser& arg_parser)
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
});
reference_batched_dropout(
ck_tile::reference_batched_dropout(
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
}
else
{
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
}
// O = P * V
reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
// clang-format off
@@ -652,28 +667,28 @@ bool run(const ArgParser& arg_parser)
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
for(ck::index_t wb = 0; wb < batch; ++wb)
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
Tensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
Tensor<AccDataType> ds_hp_host_ref(
ck_tile::HostTensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
ck_tile::HostTensor<AccDataType> ds_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
Tensor<GemmDataType> ds_lp_host_ref(
ck_tile::HostTensor<GemmDataType> ds_lp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
Tensor<AccDataType> dp_hp_host_ref(
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
Tensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
Tensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
Tensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
Tensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
// clang-format off
if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
@@ -683,12 +698,12 @@ bool run(const ArgParser& arg_parser)
// dP = dO@V x Z w/ dropout
// dP = dO@V w/o dropout
auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
if(p_drop > 0)
{
reference_batched_dropout(
ck_tile::reference_batched_dropout(
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
}
@@ -699,56 +714,59 @@ bool run(const ArgParser& arg_parser)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
do_dot_o += ck::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
ck::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
}
self(idx_gmn) = ck::type_convert<AccDataType>(p_hp_host_refs[wb](idx_gmn) *
(dp_hp_host_ref(idx_gmn) - do_dot_o));
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
});
if(use_dbias)
{
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
dbias_host_ref(idx) = ck::type_convert<BiasGradDataType>(self(idx));
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
});
}
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
ds_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
ck_tile::reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ds_lp_host_ref,
k_t_host_ref,
dq_host_ref,
ck::identity{},
ck::identity{},
[&scale](const AccDataType& x) { return scale * x; }); // dq_g_m_k = ds_g_m_n@k_g_k_n
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ds_t_lp_host_ref,
q_t_host_ref,
dk_host_ref,
ck::identity{},
ck::identity{},
[&scale](const AccDataType& x) { return scale * x; }); // dk_g_n_k = ds_g_n_m@q_g_k_m
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
Tensor<QGradDataType> dq_host_result({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
Tensor<KGradDataType> dk_host_result({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
Tensor<VGradDataType> dv_host_result({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
Tensor<BiasGradDataType> dbias_host_result(
ck_tile::HostTensor<QGradDataType> dq_host_result(
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_result(
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_result(
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
// clang-format off
@@ -764,36 +782,36 @@ bool run(const ArgParser& arg_parser)
if(use_dbias)
{
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2] + key_offset); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2] + key_offset); });
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool dq_cur_pass = ck::utils::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
rtol,
atol);
bool dk_cur_pass = ck::utils::check_err(dk_host_result,
dk_host_ref,
std::string("Error: KGrad Incorrect results!"),
rtol,
atol);
bool dv_cur_pass = ck::utils::check_err(dv_host_result,
dv_host_ref,
std::string("Error: VGrad Incorrect results!"),
rtol,
atol);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
rtol,
atol);
bool dk_cur_pass = ck_tile::check_err(dk_host_result,
dk_host_ref,
std::string("Error: KGrad Incorrect results!"),
rtol,
atol);
bool dv_cur_pass = ck_tile::check_err(dv_host_result,
dv_host_ref,
std::string("Error: VGrad Incorrect results!"),
rtol,
atol);
bool dbias_cur_pass = true;
if(use_dbias)
{
dbias_cur_pass = ck::utils::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
@@ -822,11 +840,11 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck::half_t>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck::bhalf_t>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
return -3;

View File

@@ -0,0 +1,346 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include <type_traits>
// Maps the element type selected for the example to the type of every tensor that takes part
// in the backward pass. The primary template is intentionally left undefined: only the
// specializations below can be used.
template <typename DataType>
struct FmhaBwdTypeConfig;

// Alias set shared by the supported element types: all inputs, outputs and gradients use the
// element type itself, while LSE / accumulation / D use float and random values use uint8_t.
template <typename ElementType_>
struct FmhaBwdTypeConfigCommon_
{
    // forward-pass tensors
    using QDataType    = ElementType_;
    using KDataType    = ElementType_;
    using VDataType    = ElementType_;
    using GemmDataType = ElementType_;
    using BiasDataType = ElementType_;
    using ODataType    = ElementType_;

    // float / integer helpers
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;

    // gradients
    using OGradDataType    = ElementType_;
    using QGradDataType    = ElementType_;
    using KGradDataType    = ElementType_;
    using VGradDataType    = ElementType_;
    using BiasGradDataType = ElementType_;
};

template <>
struct FmhaBwdTypeConfig<ck_tile::half_t> : FmhaBwdTypeConfigCommon_<ck_tile::half_t>
{
};

template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t> : FmhaBwdTypeConfigCommon_<ck_tile::bf16_t>
{
};
// Named aliases for the attention-mask instantiations used by this example.
// NOTE(review): the bool template arguments are presumably <IsMasking, IsLocal> of
// ck_tile::GenericAttentionMask — confirm against the ck_tile fmha headers.
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will be passed to karg, some will be used to compute grids/blocks
// Field order is part of the interface for any caller that brace-initializes this struct.
// NOTE(review): std::tuple and uint64_t are used below but <tuple>/<cstdint> are not included
// directly by this header — presumably they arrive transitively; confirm.
struct fmha_bwd_args
{
// tensor pointers (const = not written through this struct)
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
// forwarded only when group-mode kernel arguments are built
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
// forwarded only when batch-mode kernel arguments are built
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
// batch / max_seqlen_* are only used to compute the launch grid in this header:
// max_seqlen_k for the dq/dk/dv kernel, max_seqlen_q for the dot_do_o kernel
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
// nhead_q / nhead_k is what gets forwarded to the dq/dk/dv kernel; nhead_q must be a
// multiple of nhead_k (asserted in fmha_bwd_dq_dk_dv_create_kargs_and_grids)
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
// per-tensor strides
// NOTE(review): presumably row strides in elements — confirm against the kernel
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias;
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
// per-head strides (lse and d share nhead_stride_lsed)
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dbias;
// per-batch strides; apart from batch_stride_lsed these are only forwarded in batch mode
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
// masking parameters, forwarded unchanged to the dq/dk/dv kernel
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// dropout: p_drop / s_randval / drop_seed_offset go to the dq/dk/dv kernel,
// p_undrop goes to the dot_do_o kernel
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
// Builds the kernel arguments and the launch grid for the dq/dk/dv backward kernel.
//
// FmhaBwdDQDKDVKernel must provide kIsGroupMode, MakeKargs(...) and GridSize(...).
// The argument lists below are positional, so their order must stay in sync with the
// corresponding MakeKargs overload of the kernel.
// Returns ck_tile::make_tuple(kargs, grids).
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
// the head ratio forwarded below is an integer division, so it has to be exact
assert(args.nhead_q % args.nhead_k == 0);
// immediately-invoked lambda: the two branches return different kargs types, and
// `if constexpr` keeps only the branch that matches the kernel
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
// group mode: sequence boundaries come from the seqstart/seqlen pointers, and
// batch_stride_lsed is the only batch stride that is forwarded
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
// batch mode: explicit seqlen_q/seqlen_k plus the full set of batch strides
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
// the grid of this kernel is sized by the longest key sequence (max_seqlen_k)
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
// Builds the kernel arguments and the launch grid for the OGrad-dot-O kernel, which reads
// o_ptr / do_ptr and receives d_ptr as its writable buffer.
//
// FmhaBwdOGradDotOKernel must provide kIsGroupMode, MakeKargs(...) and GridSize(...).
// The argument lists are positional and must match the kernel's MakeKargs overloads.
// Returns ck_tile::make_tuple(kargs, grids).
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
// immediately-invoked lambda so that each mode can return its own kargs type
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
// group mode: sequence starts come from seqstart_q_ptr; only the lse/d batch stride
// is forwarded
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
}
else
{ // create batch mode kernel arguments
// batch mode: explicit seqlen_q plus the batch strides of dO, O and lse/d
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
// the grid of this kernel is sized by the longest query sequence (max_seqlen_q)
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
// Every template parameter is re-exported under the same name (without the trailing
// underscore) so that code holding only the traits type can read the configuration back.
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
// cv/ref qualifiers are stripped from the type parameters
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
// Entry point of the dq/dk/dv kernel for one traits combination. Only declared here; the
// explicit specializations are emitted into generated .cpp files by generate.py.
// NOTE(review): the float result is whatever ck_tile::launch_kernel returns — presumably the
// measured kernel time; confirm.
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
// Pattern-matching traits for the OGrad-dot-O kernel; like fmha_bwd_dq_dk_dv_traits_ it only
// re-exports its template parameters and does not instantiate a kernel by itself.
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
// cv/ref qualifiers are stripped from the data type
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
// Entry point of the OGrad-dot-O kernel for one traits combination. Only declared here.
// NOTE(review): the specializations are presumably emitted by generate.py in the same way as
// fmha_bwd_dq_dk_dv_ — confirm.
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
// This is the public API, will be generated by script
// Runtime description of the requested kernel. The dispatcher generated by generate.py
// compares these fields against the compiled-in kernel variants.
// NOTE(review): std::string is used below but <string> is not included directly by this
// header — presumably it arrives transitively; confirm.
struct fmha_bwd_traits
{
// a variant for head dim H is chosen when hdim_q <= H and hdim_v <= H
int hdim_q;
int hdim_v;
// matched by string comparison against the data type names known to generate.py
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bool has_bias;
bool has_dbias;
bool has_dropout;
// TODO: padding check is inside this api
};
// Public entry point; its definition is generated by generate.py (FMHA_BWD_API template).
// The generated body runs the dot_do_o kernel followed by the dq/dk/dv kernel of the first
// matching variant and returns the sum of both results, or -1 when no variant matches.
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

View File

@@ -65,7 +65,6 @@ BOOL_MAP = {
"f" : "false"
}
DIRECTIONS = ["fwd"]
GEN_DIR = "" # in Cmake, have to generate files in same folder
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
@@ -469,7 +468,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -507,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
if d == None:
continue
@@ -536,39 +535,574 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
# Short pipeline tag -> C++ pipeline class template name substituted for {F_pipeline}
# in FMHA_BWD_DQ_DK_DV_KERNEL_BODY. Keys must match BWD_DQDKDV_PIPELINE_ENUM_MAP.
BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
}
# Short pipeline tag -> C++ enumerator substituted for {F_pipeline_enum}, i.e. the value
# used as a template argument of fmha_bwd_dq_dk_dv_traits_. Keys must match
# BWD_DQDKDV_PIPELINE_MAP.
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
}
# Prologue prepended to every generated backward source file (kernel files and the API file).
# The string is not raw, so the trailing "\n" escape emits one extra blank line after the
# copyright line.
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_bwd.hpp"
"""
# C++ source template for one dq/dk/dv kernel instance; filled in with str.format() by
# FmhaBwdDQDKDVKernel.template.
#  - every {F_*} placeholder must be supplied as a keyword argument of that format() call
#  - doubled braces ({{ and }}) become literal C++ braces after formatting
#  - {F_idx} is appended to all type aliases to keep the generated symbols distinct
# The template ends with the explicit specialization of fmha_bwd_dq_dk_dv_<> that builds the
# kernel arguments and launches the kernel.
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
{F_dropout},
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_epilogue_{F_idx} =
ck_tile::FmhaBwdEpilogue<ck_tile::FmhaBwdEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
}}
"""
# Name of the generated file that holds the fmha_bwd() dispatcher.
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
# Outer skeleton of the dispatcher: {F_dispatch} receives the concatenated per-dtype blocks.
# The generated function returns -1 when no branch matches.
FMHA_BWD_API="""
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
# One branch per data type; {F_if} is either 'if' or 'else if'.
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# One branch per supported head dim; it matches when both head dims fit into {F_hdim}.
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
# One branch per kernel variant: checks mode/mask/bias/dropout and the padding conditions,
# then launches the dot_do_o kernel followed by the dq/dk/dv kernel and returns the summed
# result. {F_spad0} is the seqlen_q padding flag of the dq/dk/dv kernel, {F_spad1} the one of
# the dot_do_o kernel.
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_dot_do_o_<dot_do_o_trait_>(s, a);
r += fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_>(s, a);
return r;
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
    """Facts about one dq/dk/dv kernel variant needed to emit its dispatch branch.

    The ``*check`` members return C++ boolean expressions as text; they are pasted
    into the generated ``fmha_bwd`` dispatcher to decide at run time whether this
    variant's padding configuration fits the problem size.
    """
    pipeline : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim     : str
    dtype    : str  # data type
    mode     : str  # value from MODE_MAP
    bm0      : int  # tile size along q seqlen (block size)
    bn0      : int  # tile size along k seqlen
    bhdq     : int  # q head_dim
    bhdv     : int  # v head_dim
    mask     : str
    bias     : str  # true/false
    dbias    : str
    dropout  : str
    spad     : str
    skpad    : str
    dpad     : str
    dvpad    : str

    @property
    def name(self) -> str:
        """Dash-separated identifier built from every field except the tile sizes."""
        parts = (self.pipeline, self.hdim, self.dtype, self.mode, self.mask,
                 self.bias, self.dbias, self.dropout,
                 self.spad, self.skpad, self.dpad, self.dvpad)
        return '-'.join(f'{part}' for part in parts)

    def scheck(self, spad1 : str) -> str:
        """C++ condition on ``a.seqlen_q`` for this variant.

        ``self.spad`` is the seqlen_q padding flag of the dq/dk/dv kernel and
        ``spad1`` the one of the dot_do_o kernel; 256 is the block size used by
        the dot_do_o kernel.
        """
        if self.mode == 'group':
            return 'true' # always support
        if spad1 == 't':
            if self.spad == 't':
                return f'a.seqlen_q % {self.bm0} != 0'
            if self.spad == 'f':
                return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
        # remaining case: self.spad == 'f' and spad1 == 'f'
        return 'a.seqlen_q % 256 == 0' # BlockSize

    @property
    def skcheck(self) -> str:
        """C++ condition on ``a.seqlen_k`` (always true in group mode)."""
        if self.mode == 'group':
            return 'true' # always support
        relation = '!=' if self.skpad == 't' else '=='
        return f'a.seqlen_k % {self.bn0} {relation} 0'

    @property
    def dcheck(self) -> str:
        """C++ condition on ``a.hdim_q``: padded variants serve non-multiples of bhdq."""
        relation = '!=' if self.dpad == 't' else '=='
        return f'a.hdim_q % {self.bhdq} {relation} 0'

    @property
    def dvcheck(self) -> str:
        """C++ condition on ``a.hdim_v``: padded variants serve non-multiples of bhdv."""
        relation = '!=' if self.dvpad == 't' else '=='
        return f'a.hdim_v % {self.bhdv} {relation} 0'
class FmhaBwdApiPool:
    """Collects dq/dk/dv kernel traits and renders the ``fmha_bwd`` dispatcher source.

    Traits are stored as ``dq_dk_dv_pool[dtype][hdim] -> list of traits``; insertion
    order is preserved and determines the order of the generated branches.
    """

    def __init__(self, mask_impl):
        self.dq_dk_dv_pool = dict()
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        """Store a shallow copy of ``trait`` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        per_hdim = self.dq_dk_dv_pool.setdefault(trait.dtype, dict())
        per_hdim.setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Full C++ source of the dispatcher: dtype -> hdim -> kernel variant."""
        def branch_keyword(position : int) -> str:
            # the first branch of a chain opens with 'if', the following ones chain on
            return 'if' if position == 0 else 'else if'

        dtype_blocks = []
        for i, (dtype, per_hdim) in enumerate(self.dq_dk_dv_pool.items()):
            hdim_blocks = []
            for j, (hdim, traits) in enumerate(per_hdim.items()):
                branches = []
                for k, trait in enumerate(traits):
                    # spad1 is the seqlen_q padding flag of the dot_do_o kernel
                    for spad1 in ("t", "f"):
                        # the unpadded dot_do_o kernel is never combined with a padded
                        # dq/dk/dv kernel or with group mode
                        if spad1 == "f" and (trait.spad == "t" or trait.mode == "group"):
                            continue
                        branches.append(FMHA_BWD_API_INNER_DISPATCH.format(
                            F_if=branch_keyword(k),
                            F_mode=MODE_MAP[trait.mode],
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_bias=BOOL_MAP[trait.bias],
                            F_dbias=BOOL_MAP[trait.dbias],
                            F_dropout=BOOL_MAP[trait.dropout],
                            F_scheck=trait.scheck(spad1=spad1),
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_hdim=hdim,
                            F_dtype=DTYPE_MAP[dtype],
                            F_spad0=BOOL_MAP[trait.spad],
                            F_spad1=BOOL_MAP[spad1],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad]))
                hdim_blocks.append(FMHA_BWD_API_PER_HDIM_CASE.format(
                    F_if=branch_keyword(j), F_hdim=hdim, F_inner_dispatch=''.join(branches)))
            dtype_blocks.append(FMHA_BWD_API_PER_DTYPE.format(
                F_if=branch_keyword(i), F_dtype=dtype, F_hdim_case=''.join(hdim_blocks)))
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=''.join(dtype_blocks))
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Tile / warp configuration of one dq/dk/dv kernel.

    Field order matters: instances are created positionally. The field names equal
    the ``{F_*}`` placeholders of the kernel source template.
    """
    # block tile
    F_bm0       : int  # tile size along q seqlen (block size)
    F_bn0       : int  # tile size along k seqlen
    F_bk0       : int  # tile size along gemm0 unroll(F_bhdq)
    F_bk1       : int  # tile size along gemm1 unroll(F_bm0)
    F_bk2       : int  # tile size along gemm2 unroll(F_bhdv)
    F_bk3       : int  # tile size along gemm3 unroll(F_bm0)
    F_bk4       : int  # tile size along gemm4 unroll(F_bn0)
    F_bhdq      : int  # q head_dim
    F_bhdv      : int  # v head_dim
    # block warps
    # NOTE(review): the GEMM notes in this file state that N1/N3/N4 equal a head dim, so the
    # "along q seqlen" wording for F_rn1 / F_rn2 below may be inaccurate - confirm
    F_rm0       : int  # number of warps along q seqlen (block warps) in gemm0/gemm2
    F_rn0       : int  # number of warps along k seqlen (block warps) in gemm0/gemm2
    F_rk0       : int  # number of warps along gemm-k (not used) in gemm0/gemm2
    F_rm1       : int  # number of warps along k seqlen (block warps) in gemm1/gemm3
    F_rn1       : int  # number of warps along q seqlen (block warps) in gemm1/gemm3
    F_rk1       : int  # number of warps along gemm-k (not used) in gemm1/gemm3
    F_rm2       : int  # number of warps along k seqlen (block warps) in gemm4
    F_rn2       : int  # number of warps along q seqlen (block warps) in gemm4
    F_rk2       : int  # number of warps along gemm-k (not used) in gemm4
    # warp tile
    F_wm        : int  # warp size along m (warp size)
    F_wn        : int  # warp size along n
    F_wk        : int  # warp size along k
    F_occupancy : int  # occupancy

    @property
    def name(self) -> str:
        """Short tag used in kernel names: ``b<bm0>x<bn0>``."""
        return "b{}x{}".format(self.F_bm0, self.F_bn0)
@dataclass
class FmhaBwdDQDKDVKernel:
    """One generated dq/dk/dv kernel instance: its configuration, source text and file name."""
    direction  : str
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaBwdDQDKDVTileSize
    F_spad     : str  # true/false
    F_skpad    : str  #
    F_dpad     : str  #
    F_dvpad    : str  #
    F_bias     : str  #
    F_dbias    : str  #
    F_dropout  : str  #
    F_mask     : str  # value from MASK_MAP
    F_mode     : str  # value from MODE_MAP
    F_pipeline : str
    mask_impl  : str

    @property
    def template(self) -> str:
        """C++ source of this kernel: the common header followed by the filled-in body."""
        # The tile fields are named exactly like their placeholders, so they are passed
        # through as-is; everything else is translated to its C++ spelling first.
        body = FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
            F_idx           = self.F_idx,
            F_hdim          = self.F_hdim,
            F_dtype         = DTYPE_MAP[self.F_dtype],
            F_spad          = BOOL_MAP[self.F_spad],
            F_skpad         = BOOL_MAP[self.F_skpad],
            F_dpad          = BOOL_MAP[self.F_dpad],
            F_dvpad         = BOOL_MAP[self.F_dvpad],
            F_bias          = BOOL_MAP[self.F_bias],
            F_dbias         = BOOL_MAP[self.F_dbias],
            F_dropout       = BOOL_MAP[self.F_dropout],
            F_mask          = get_mask_map(self.mask_impl)[self.F_mask],
            F_mode          = MODE_MAP[self.F_mode],
            F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
            F_pipeline      = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline],
            **vars(self.F_tile))
        return FMHA_BWD_KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Kernel name encoding direction, hdim, dtype, mode, tile, padding and feature flags."""
        # mask suffix: only 's_mask' among the 's_*' masks is encoded; other masks contribute
        # their first letter unless the mask is 'no'
        if self.F_mask[0:2] == 's_':
            mask_suffix = '_mask' if self.F_mask == 's_mask' else ''
        elif self.F_mask != 'no':
            mask_suffix = f'_m{self.F_mask[0]}'
        else:
            mask_suffix = ''

        # TODO: we don't encode idx here
        pad_flags = ''.join(BOOL_MAP[flag][0]
                            for flag in (self.F_spad, self.F_skpad, self.F_dpad, self.F_dvpad))
        return (f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
                f"{self.F_tile.name}"
                f"_p{pad_flags}"
                f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
                f"{mask_suffix}")

    @property
    def filename(self) -> str:
        """File name of the generated source: the kernel name plus ``.cpp``."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
        """Reduce this kernel to the facts the dispatcher generator needs."""
        tile = self.F_tile
        return FmhaBwdDQDKDVApiTrait(
            pipeline=self.F_pipeline,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bhdq=tile.F_bhdq,
            bhdv=tile.F_bhdv,
            mask=self.F_mask,
            bias=self.F_bias,
            dbias=self.F_dbias,
            dropout=self.F_dropout,
            spad=self.F_spad,
            skpad=self.F_skpad,
            dpad=self.F_dpad,
            dvpad=self.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return the supported ``hdim -> [tile size, pipeline tag]`` table.

    Only ``direction == 'bwd'`` with dtype ``'fp16'`` or ``'bf16'`` is supported;
    every other combination yields ``None``.
    """
    if direction != 'bwd' or dtype not in ('fp16', 'bf16'):
        return None

    tile_hdim32  = FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1)
    tile_hdim64  = FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1)
    tile_hdim128 = FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1)
    return {
        '32'  : [tile_hdim32,  "qs_ks_vr_dos"],
        '64'  : [tile_hdim64,  "ks_kts_vr"],
        '128' : [tile_hdim128, "ks_vr"],
    }
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    """Enumerate the backward dQ/dK/dV kernels to generate.

    Walks every combination of dtype, head dim, mode, mask and the seven t/f flags
    (bias, dbias, dropout and the four pads), drops the unsupported ones, applies
    the optional fnmatch-style ``kernel_filter`` to the kernel name, and registers
    each surviving kernel's trait with the API pool.

    Returns ``(api_pool, kernels)``.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad
    #       support this in future
    kernels = list()
    api_pool = FmhaBwdApiPool(mask_impl)
    tf = ["t", "f"]
    direction = "bwd"

    for dtype in DTYPE_MAP.keys():
        tile_ppl = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if tile_ppl is None:
            continue

        combos = itertools.product(tile_ppl.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(),
                                   tf, tf, tf, tf, tf, tf, tf)
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in combos:
            entry = tile_ppl[hdim_str]
            tile, ppl = entry[0], entry[1]
            hdim = int(hdim_str)

            # group mode is only generated with both sequence pads enabled
            if mode == "group" and "f" in (spad, skpad):
                continue
            # a bias gradient without a bias is not generated
            if bias == "f" and dbias == "t":
                continue

            kernel = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                                         F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                                         F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
                                         F_pipeline=ppl, mask_impl=mask_impl)
            if kernel_filter is not None and not fnmatch.fnmatch(kernel.name, kernel_filter):
                continue

            api_pool.register_dq_dk_dv_traits(kernel.api_trait())
            kernels.append(kernel)

    return (api_pool, kernels)
# C++ source template for one generated "dot_do_o" backward kernel instance.
# It is filled in with str.format() by FmhaBwdOGradDotOKernel.template, so single
# braces are format fields ({F_idx}, {F_dtype}, {F_hdim}, {F_mode}, {F_spad},
# {F_dvpad}, {F_occupancy}) and doubled braces ({{ }}) are literal C++ braces.
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
{F_dvpad},
{F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
template<>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
}}
"""
@dataclass
class FmhaBwdOGradDotOKernel:
    """Description of one generated backward "dot_do_o" kernel instance.

    ``template`` renders the C++ source for the instance, ``name`` is its symbolic
    name and ``filename`` is the file it is written to.
    """
    direction   : str
    F_idx       : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim      : int  # hdim
    F_dtype     : str  # data type
    F_spad      : str  # true/false
    F_dvpad     : str  #
    F_mode      : str  # value from MODE_MAP
    F_occupancy : int

    @property
    def template(self) -> str:
        """Render the C++ source: the shared backward header followed by the kernel body."""
        fields = dict(
            F_idx       = self.F_idx,
            F_hdim      = self.F_hdim,
            F_dtype     = DTYPE_MAP[self.F_dtype],
            F_spad      = BOOL_MAP[self.F_spad],
            F_dvpad     = BOOL_MAP[self.F_dvpad],
            F_mode      = MODE_MAP[self.F_mode],
            F_occupancy = self.F_occupancy,
        )
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(**fields)

    @property
    def name(self) -> str:
        """Symbolic kernel name built from direction, hdim, dtype, mode, pad flags and occupancy."""
        # TODO: we don't encode idx here
        pieces = [
            f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}",
            f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_dvpad][0]}",
            f"_o{self.F_occupancy}",
        ]
        return "".join(pieces)

    @property
    def filename(self) -> str:
        """Name of the generated source file: the kernel name plus ``.cpp``."""
        return f"{self.name}.cpp"
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate the backward "dot_do_o" kernels to generate.

    One kernel is produced per (dtype, hdim, mode, spad, dvpad) combination, except
    that group mode is only generated with ``spad == "t"``. The head dims are taken
    from the dq/dk/dv tile table so both kernel families cover the same sizes.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    kernels = list()
    tf = ["t", "f"]
    direction = "bwd"

    for dtype in DTYPE_MAP.keys():
        tile_ppl = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if tile_ppl is None:
            continue

        for hdim_str, mode, spad, dvpad in itertools.product(tile_ppl.keys(), MODE_MAP.keys(), tf, tf):
            if mode == "group" and spad == "f":
                continue
            hdim = int(hdim_str)
            kernels.append(
                FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                       F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                                       F_occupancy=get_occupancy(dtype, hdim)))
    return kernels
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_api(api_pool, output_dir)
if direction == 'fwd':
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
write_bwd_api(api_pool, output_dir)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
with file_path.open('a') as f:
_, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
if direction == 'fwd':
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen api for CK fmha kernel",
)
parser.add_argument(
"-d",
"--direction",
default='fwd',
choices=['fwd', 'bwd'],
required=False,
help="choose the direction of kernels(default: fwd)"
)
parser.add_argument(
"-o",
"--output_dir",
@@ -608,6 +1142,6 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.list_blobs is not None:
list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
list_blobs(args.list_blobs, args.direction, args.filter, args.receipt, mask_impl=args.mask)
else:
write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)
write_blobs(args.output_dir, args.direction, args.filter, args.receipt, mask_impl=args.mask)

View File

@@ -0,0 +1,21 @@
#!/bin/sh
# Benchmark sweep for tile_example_fmha_bwd.
# For every precision, permute setting and head dim, the example is run over six
# (batch, seqlen) pairs, with a 3 second pause after each run.
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0

for prec in "fp16" "bf16" ; do
    for perm in 0 1 ; do
        for hdim in 32 64 128 ; do
            nhead=$((2048 / $hdim)) # follow fav2 setup
            # each entry is "<batch> <seqlen>"
            for shape in "32 512" "16 1024" "8 2048" "4 4096" "2 8192" "1 16384" ; do
                batch=${shape% *}
                seqlen=${shape#* }
                $EXE -prec=$prec -b=$batch -h=$nhead -d=$hdim -s=$seqlen -iperm=$perm -operm=$perm -kname=1 -v=$VALID
                sleep 3
            done
        done
    done
done

View File

@@ -0,0 +1,33 @@
#!/bin/sh
# Smoke test for tile_example_fmha_bwd.
# Runs six fixed problem shapes (different batch/head/seqlen/mask settings) for every
# combination of precision, permute setting, head dim, mode, bias, dbias and dropout.
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1

COMMON_ARGS='-v=1 -warmup=0 -repeat=1'

for prec in "fp16" "bf16" ; do
    for perm in 0 1 ; do
        for hdim in 32 64 128 ; do
            for mode in 0 1 ; do
                for bias in 0 1 ; do
                    for dbias in 0 1 ; do
                        for p_drop in 0.0 0.2; do
                            # argument groups shared by all six runs of this combination;
                            # left unquoted at the call sites so they split into words
                            GRAD_ARGS="-bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm"
                            TAIL_ARGS="-v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS"

                            $EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 $GRAD_ARGS $TAIL_ARGS
                            $EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 $GRAD_ARGS $TAIL_ARGS
                            $EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 $GRAD_ARGS -mask=1 $TAIL_ARGS
                            $EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 $GRAD_ARGS -mask=2 $TAIL_ARGS
                            $EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 $GRAD_ARGS -mask=t:128,30 $TAIL_ARGS
                            $EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 $GRAD_ARGS -mask=b:4,35 $TAIL_ARGS
                        done
                    done
                done
            done
        done
    done
done

View File

@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
@@ -37,6 +38,7 @@
#include "ck_tile/core/tensor/slice_tile.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/store_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/sweep_tile.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"

View File

@@ -765,21 +765,21 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
// buffer store ui16
__device__ void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
__device__ void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -1658,7 +1658,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,

View File

@@ -0,0 +1,175 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {
/// @brief Add two bf16 values: both are converted to float, summed, and the sum is
///        converted back to bf16.
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
{
    const float sum = type_convert<float>(a) + type_convert<float>(b);
    return type_convert<bf16_t>(sum);
}
/// @brief Element-wise sum of two 2-wide bf16 vectors, using add_bf16_t per lane.
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t sum;
    for(int lane = 0; lane < 2; ++lane)
    {
        sum[lane] = add_bf16_t(a[lane], b[lane]);
    }
    return sum;
}
// Caution: DO NOT REMOVE
// Primary template of atomic_add: intentionally declared but never defined, so that using it
// with a data type that has no explicit specialization fails to build instead of silently
// picking a generic implementation. The purpose is to make the implementation of atomic_add
// explicit for each datatype.
template <typename X>
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
// atomic_add specialization for a pair of bf16 values.
// The pair is handled as one 32-bit word: the current word is read, the bf16x2 sum is
// computed with add_bf16x2_t, and atomicCAS tries to publish it. The loop repeats with
// the freshly observed word whenever the destination changed in between.
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
// views the destination pointer either as a 32-bit word pointer or as a bf16x2 pointer
union U32BF162_ADDR
{
uint32_t* u32_a;
bf16x2_t* bf162_a;
};
// views one 32-bit word either as raw bits or as two bf16 lanes
union U32BF162
{
uint32_t u32;
bf16x2_t bf162;
};
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
// initial (non-atomic) read of the destination word
cur_v.u32 = *dword_addr.u32_a;
do
{
// value we expect to still be in memory when the CAS executes
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
new_v = new_.u32;
// the value returned by atomicCAS is compared with old_v below; a mismatch means
// another writer got in first, so retry with the returned value
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
/// @brief Atomically add a small per-thread vector to memory through a raw pointer.
///
/// Supported combinations (enforced by the static_assert):
///   - int32_t / uint32_t : N == 1
///   - float / double     : N == 1 or N == 2
///   - bf16_t             : N == 2 or N == 4 (processed as bf16x2 pairs)
///
/// For N > 1 the elements are added with separate atomic operations on consecutive
/// addresses, so each element (or bf16 pair) is updated atomically but the vector
/// as a whole is not.
///
/// @param p_dst  address of the first destination element
/// @param x      values to add
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
                  "wrong! not implemented");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    if constexpr(std::is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<float>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
            atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, double>::value)
    {
        if constexpr(N == 1)
        {
            // The result of atomicAdd is deliberately discarded: this function returns
            // void, and returning the value here would make the <double, 1>
            // instantiation ill-formed.
            atomicAdd(p_dst, bit_cast<double>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
            atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<int32_t>(x));
        }
    }
    else if constexpr(std::is_same<T, uint32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<uint32_t>(x));
        }
    }
    else if constexpr(std::is_same<T, bf16_t>::value)
    {
        // bf16 goes through the bf16x2 atomic_add specialization, one pair at a time
        if constexpr(N == 2)
        {
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
        }
        else if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
                       x.template get_as<bf16x2_t>()[I1]);
        }
    }
}
/// @brief Atomically take the maximum of memory and a small per-thread vector.
///
/// Supported combinations (enforced by the static_assert):
///   - int32_t / uint32_t / double : N == 1
///   - float                       : N == 1 or N == 2
///
/// For float with N == 2 the two elements are updated with two separate atomicMax
/// calls on consecutive addresses.
///
/// @param p_dst  address of the first destination element
/// @param x      candidate values
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1)),
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        // every supported scalar type takes the same single-element path
        atomicMax(p_dst, bit_cast<T>(x));
    }
    else if constexpr(std::is_same_v<T, float> && N == 2)
    {
        constexpr auto lane0 = number<0>{};
        constexpr auto lane1 = number<1>{};

        float* const p_float = c_style_pointer_cast<float*>(p_dst);
        atomicMax(p_float, x.template get_as<float>()[lane0]);
        atomicMax(p_float + 1, x.template get_as<float>()[lane1]);
    }
}
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing = false;
#endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
}
@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
{
if(is_valid_element)
{
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
}
}
}
@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing = false;
#endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
}
else if(is_valid_element)
{
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
}
}

View File

@@ -16,7 +16,9 @@
namespace ck_tile {
template <typename BufferView_, typename TensorDesc_>
template <typename BufferView_,
typename TensorDesc_,
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
struct tensor_view
{
using buffer_view = remove_reference_t<BufferView_>;
@@ -24,6 +26,7 @@ struct tensor_view
using TensorDesc = remove_cvref_t<TensorDesc_>;
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
static constexpr auto DstInMemOp = DstInMemOp_;
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
@@ -140,6 +143,23 @@ struct tensor_view
x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename DataType,
typename... Lengths,
typename... Strides,
@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{});
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
old_tensor_view.buf_, new_desc};
return tensor_view<typename OldTensorView::buffer_view,
remove_cvref_t<decltype(new_desc)>,
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
}
template <typename TensorView,

View File

@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
});
}
// Update the bottom tensor with the contents of a distributed tensor.
// Has the same traversal as a vectorized store, but each vector is handed to the bottom
// tensor view's update_vectorized_elements() instead of being stored directly.
// oob_conditional_check is forwarded unchanged to update_vectorized_elements().
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
// local copies of the pre-computed coordinates; they are advanced below
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
// gather ScalarPerVector elements along VectorDimY into one vector
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
// move thread coordinate
// (skipped after the last access of this coordinate group)
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin

View File

@@ -0,0 +1,55 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief Update the memory behind a static-lengths tile window with a distributed tensor.
///
/// A tile window carrying the tensor's tile distribution is built on top of the given
/// window (same bottom tensor view, lengths and origin) and its update() is invoked.
/// The tensor's data type must match the bottom tensor view's data type.
///
/// @param tile_window_tmp  window describing the destination region
/// @param dstr_tensor      per-thread distributed values to apply
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using BottomDataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using Distribution   = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, BottomDataType>, "wrong!");

    constexpr auto distribution = Distribution{};

    auto distributed_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                               tile_window_tmp.get_window_lengths(),
                                               tile_window_tmp.get_window_origin(),
                                               distribution);

    distributed_window.update(dstr_tensor);
}
// Overload for a window that already carries a tile distribution: forwards directly to
// the window's own update(). The tensor must use the same TileDistribution_ as the window.
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.update(dstr_tensor);
}
} // namespace ck_tile

View File

@@ -4,4 +4,5 @@
#pragma once
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/custom_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Type bundle consumed by FmhaBwdEpilogue: the accumulator data type and the data
// types of the K-gradient and V-gradient outputs, with cv/ref qualifiers stripped.
template <typename AccDataType_, typename KGradDataType_, typename VGradDataType_>
struct FmhaBwdEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using KGradDataType = remove_cvref_t<KGradDataType_>;
using VGradDataType = remove_cvref_t<VGradDataType_>;
};
/// @brief Epilogue of the FMHA backward kernel: converts the dK and dV accumulator tiles
///        to their output data types and stores them through the given DRAM windows.
///
/// @tparam Problem_  provides AccDataType, KGradDataType and VGradDataType
/// @tparam Policy_   not referenced by this epilogue
template <typename Problem_, typename Policy_ = void>
struct FmhaBwdEpilogue
{
    using Problem       = remove_cvref_t<Problem_>;
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;

    /// This epilogue requests no shared memory.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    /// Cast and store dK first, then dV.
    template <typename KGradDramWindowTmp,
              typename VGradDramWindowTmp,
              typename KGradAccTile,
              typename VGradAccTile>
    CK_TILE_DEVICE auto operator()(KGradDramWindowTmp& dk_dram_window_tmp,
                                   VGradDramWindowTmp& dv_dram_window_tmp,
                                   const KGradAccTile& dk_acc_tile,
                                   const VGradAccTile& dv_acc_tile)
    {
        const auto dk_out_tile = cast_tile<KGradDataType>(dk_acc_tile);
        store_tile(dk_dram_window_tmp, dk_out_tile);

        const auto dv_out_tile = cast_tile<VGradDataType>(dv_acc_tile);
        store_tile(dv_dram_window_tmp, dv_out_tile);
    }
};
} // namespace ck_tile

View File

@@ -17,6 +17,19 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// @brief Maps the launch grid of the FMHA backward dQ/dK/dV kernel to work items.
///
/// The grid is (ceil(seqlen_k / kN0), nhead, batch_size); each block reads its
/// (block, head, batch) indices straight from blockIdx.
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    /// Grid dimensions for the given problem size.
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        const auto num_k_tiles = ck_tile::integer_divide_ceil(seqlen_k_, kN0);
        return dim3(num_k_tiles, nhead_, batch_size_);
    }

    /// Returns (i_block, i_nhead, i_batch) for the calling block; the argument is unused.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = static_cast<index_t>(blockIdx.x);
        const index_t i_nhead = static_cast<index_t>(blockIdx.y);
        const index_t i_batch = static_cast<index_t>(blockIdx.z);

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
/// @brief Maps the launch grid of the FMHA backward "dot_do_o" kernel to work items.
///
/// The grid is (ceil(seqlen_q / kBlockSize), nhead, batch_size); each block reads its
/// (block, head, batch) indices straight from blockIdx.
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    /// Grid dimensions for the given problem size.
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        const auto num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize);
        return dim3(num_q_blocks, nhead_, batch_size_);
    }

    /// Returns (i_block, i_nhead, i_batch) for the calling block; the argument is unused.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = static_cast<index_t>(blockIdx.x);
        const index_t i_nhead = static_cast<index_t>(blockIdx.y);
        const index_t i_batch = static_cast<index_t>(blockIdx.z);

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -18,10 +18,10 @@ struct FmhaFwdTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *

View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
namespace ck_tile {
// Block-level routine that computes, for every row of the loaded tile,
//   D[i] = p_undrop * sum_j( O[i][j] * dO[i][j] )
// from the O and dO windows and stores the result to the D window.
// Data types, block size and padding flags come from Problem; tile distributions and
// alignments come from Policy.
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kVHeaddim = Problem::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
// alignment falls back to 1 when the V head dim is padded
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
// no shared memory is requested by this routine
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
// o_dram_block_window_tmp  : window over O (read)
// do_dram_block_window_tmp : window over dO (read)
// d_dram_block_window_tmp  : window over D (written)
// p_undrop                 : scale applied to every element of D before the store
template <typename ODramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename DDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
DDramBlockWindowTmp& d_dram_block_window_tmp,
float p_undrop) const
{
// the windows must carry exactly the data types declared by Problem
static_assert(
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
// the first window length of all three windows must equal kBlockSize
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize ==
OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
// re-create the O window with the policy's tile distribution and load it
auto o_dram_window =
make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
o_dram_block_window_tmp.get_window_lengths(),
o_dram_block_window_tmp.get_window_origin(),
Policy::template MakePreODramTileDistribution<Problem>());
auto o = load_tile(o_dram_window);
// same for dO
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
do_dram_block_window_tmp.get_window_origin(),
Policy::template MakePreOGradDramTileDistribution<Problem>());
auto do_ = load_tile(do_dram_window);
// declare d
// d uses the distribution of o with dimension 1 reduced away
constexpr auto d_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
clear_tile(d); // Initialize D
constexpr auto o_spans = decltype(o)::get_distributed_spans();
// accumulate o * do_ over the second span into d, in DDataType
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
d(i_idx) +=
(type_convert<DDataType>(o[i_j_idx]) * type_convert<DDataType>(do_[i_j_idx]));
});
});
// scale every accumulated value by p_undrop
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
store_tile(d_dram_block_window_tmp, d);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for BlockFmhaBwdOGradDotO (it is that struct's default `Policy` argument).
// It is the generic BlockFmhaBwdPipelineDefaultPolicy with every *LoadOnce_ flag set to false;
// per the original note, these template flags are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ false,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ false,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -0,0 +1,821 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// Block-level FMHA backward pipeline that produces dQ, dK and dV for one K/V tile.
//
// Data placement used by this variant ("ks_kts_vr"):
//   * K   : stored once in LDS and kept for the whole call,
//   * K^T : stored once in LDS and kept for the whole call,
//   * V   : loaded once into a register tile and kept for the whole call.
// Q, Q^T, OGrad and OGrad^T are streamed through LDS in K-slices on every iteration.
//
// The call iterates over kM0-row tiles of the Q sequence. dK and dV are accumulated across all
// iterations and returned; dQ is written on every iteration through update_tile().
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using GemmDataType          = remove_cvref_t<typename Problem::GemmDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using AccDataType           = remove_cvref_t<typename Problem::AccDataType>;
    using DDataType             = remove_cvref_t<typename Problem::DDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using ODataType             = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType         = remove_cvref_t<typename Problem::OGradDataType>;
    using QGradDataType         = remove_cvref_t<typename Problem::QGradDataType>;
    using KGradDataType         = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType         = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType      = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;

    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;

    // Tile sizes: kM0 rows of Q per iteration, kN0 rows of K/V per call, and kK0..kK4 as the
    // K-slice sizes of Gemm0..Gemm4 respectively (see the k*_loops constants in operator()).
    static constexpr index_t kM0        = BlockFmhaShape::kM0;
    static constexpr index_t kN0        = BlockFmhaShape::kN0;
    static constexpr index_t kK0        = BlockFmhaShape::kK0;
    static constexpr index_t kK1        = BlockFmhaShape::kK1;
    static constexpr index_t kK2        = BlockFmhaShape::kK2;
    static constexpr index_t kK3        = BlockFmhaShape::kK3;
    static constexpr index_t kK4        = BlockFmhaShape::kK4;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

    // Which operands are loaded once and kept resident (mirrors the default policy's flags).
    static constexpr bool kQLoadOnce      = false;
    static constexpr bool kQTLoadOnce     = false;
    static constexpr bool kKLoadOnce      = true;
    static constexpr bool kKTLoadOnce     = true;
    static constexpr bool kVLoadOnce      = true;
    static constexpr bool kOGradLoadOnce  = false;
    static constexpr bool kOGradTLoadOnce = false;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr bool kHasBias     = Problem::kHasBias;
    static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
    static constexpr bool kHasDropout  = Problem::kHasDropout;

    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    // NOTE(review): QGrad is the only alignment that falls back to 2 (not 1) when padded —
    // presumably tied to how dQ is written by update_tile(); confirm against the kernel.
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

    static constexpr const char* name = "ks_kts_vr";

    // Shared-memory size required by this pipeline, as computed by the policy.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    // Runs the backward pipeline for one K/V tile.
    //
    // Inputs (read): Q, Q^T, K, K^T, V, Bias, OGrad, OGrad^T, LSE and D windows.
    // Outputs (written): QGrad window (every iteration), BiasGrad window (only if kHasBiasGrad),
    //                    and the randval window handed to `dropout` (only if kHasDropout).
    // mask            : supplies the Q range to visit and the per-element out-of-bound test.
    // raw_scale       : scale applied to S (non-exp2 path) and to dQ/dK when there is no dropout.
    // scale           : (only with CK_TILE_FMHA_FWD_FAST_EXP2) scale used on the exp2 path.
    // rp_undrop       : factor applied to dBias and dV when dropout is enabled.
    // scale_rp_undrop : factor applied to dQ and dK when dropout is enabled.
    // smem_ptr        : base of the LDS buffer (see the layout comments in the body).
    // Returns a tuple (dk_acc, dv_acc) holding the accumulated, scaled dK and dV tiles.
    template <typename QDramBlockWindowTmp,
              typename QTDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename KTDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename OGradTDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
              typename BiasGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
               const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
               const KDramBlockWindowTmp& k_dram_block_window_tmp,
               const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
               const VDramBlockWindowTmp& v_dram_block_window_tmp,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
               const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
               const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
               const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
               const DDramBlockWindowTmp& d_dram_block_window_tmp,
               const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
               FmhaMask mask,
               float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
               float scale,
#endif
               float rp_undrop,
               float scale_rp_undrop,
               void* smem_ptr,
               BlockDropout& dropout) const
    {
        // Window element types must match the problem's data types.
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QDataType,
                               remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType,
                               remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType,
                               remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        // Window extents must match the block tile shape.
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kVHeaddim ==
                              OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // LDS layout (byte offsets from smem_ptr):
        //   [0, GetSmemSizeK)                              : K
        //   [GetSmemSizeK, GetSmemSizeK + GetSmemSizeKT)   : K^T
        //   GetSmemSizeK + GetSmemSizeKT onwards           : Q, Q^T, OGrad, OGrad^T, SGrad and
        //                                                    BiasT/BiasGradT
        // The last six views all start at the SAME address, i.e. they alias one scratch region.
        // That is why every hand-over between them below is fenced with block_sync_lds().

        // Q tile in LDS
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto q_lds           = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

        // QT tile in LDS
        QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto qt_lds           = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
        auto qt_lds_window =
            make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

        // K tile in LDS
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        // KT tile in LDS
        KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto kt_lds           = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
        auto kt_lds_window =
            make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

        // OGrad tile in LDS
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto do_lds               = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

        // OGradT tile in LDS
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto dot_lds               = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
        auto dot_lds_window =
            make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

        // SGrad tile in LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto ds_lds              = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        // BiasT/BiasGradT tile in LDS, use the same size and layout
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto biast_lds              = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());

        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // Block GEMM
        //   gemm_0: S^T accumulation from Q (LDS) and K (LDS)
        //   gemm_1: dV accumulation from P^T (registers) and OGrad^T (LDS)
        //   gemm_2: dP^T accumulation from OGrad (LDS) and V (registers)
        //   gemm_3: dK accumulation from dS^T (registers) and Q^T (LDS)
        //   gemm_4: dQ accumulation from dS (LDS) and K^T (LDS)
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
        constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

        auto v_dram_window = make_tile_window(
            v_dram_block_window_tmp.get_bottom_tensor_view(),
            v_dram_block_window_tmp.get_window_lengths(),
            v_dram_block_window_tmp.get_window_origin(),
            Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());

        auto v = load_tile(v_dram_window); // persistent V register tile

        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

        clear_tile(dv_acc);
        clear_tile(dk_acc);

        auto k_dram_window = make_tile_window(
            k_dram_block_window_tmp.get_bottom_tensor_view(),
            k_dram_block_window_tmp.get_window_lengths(),
            k_dram_block_window_tmp.get_window_origin(),
            Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                    // load

        __builtin_amdgcn_sched_barrier(0);

        // Ask the mask which Q rows interact with this K tile; only those are visited.
        const auto k_origin = k_dram_window.get_window_origin();
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);

        // check early exit if masked and no work to do.
        if constexpr(FmhaMask::IsMasking)
        {
            if(num_total_loop <= 0)
            {
                // Note: here dk_acc&dv_acc are all cleared, return it
                // Note: v loaded but no fence, ignore it.
                return ck_tile::make_tuple(dk_acc, dv_acc);
            }
        }

        auto k_block_tile = load_tile(k_dram_window);

        store_tile(k_lds_window, k_block_tile); // persistent K in LDS

        auto kt_dram_block_window = kt_dram_block_window_tmp;

        auto kt_dram_window = make_tile_window(
            kt_dram_block_window.get_bottom_tensor_view(),
            kt_dram_block_window.get_window_lengths(),
            kt_dram_block_window.get_window_origin(),
            Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
                                                                     // load

        auto kt_block_tile = load_tile(kt_dram_window);

        auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
            Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
        shuffle_tile(kt_shuffle_tmp, kt_block_tile);

        store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS

        // Re-anchor every Q-indexed window at seqlen_q_start (row axis for Q/OGrad/QGrad/LSE/D/
        // Bias/BiasGrad, column axis for the transposed Q^T/OGrad^T windows).
        auto q_dram_block_window =
            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                             q_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});

        auto qt_dram_block_window =
            make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
                             qt_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});

        auto do_dram_block_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});

        auto dot_dram_block_window =
            make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
                             dot_dram_block_window_tmp.get_window_lengths(),
                             {0, seqlen_q_start});

        auto dq_dram_block_window =
            make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});

        auto lse_dram_block_window =
            make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                             lse_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});

        auto d_dram_block_window =
            make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
                             d_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start});

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_block_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N

        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_block_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N

        // Windows that persist across iterations (they are advanced inside the loop).
        auto qt_dram_window =
            make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
                             qt_dram_block_window.get_window_lengths(),
                             qt_dram_block_window.get_window_origin(),
                             Policy::template MakeQTDramTileDistribution<Problem>());

        auto dot_dram_window =
            make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                             dot_dram_block_window.get_window_lengths(),
                             dot_dram_block_window.get_window_origin(),
                             Policy::template MakeOGradTDramTileDistribution<Problem>());

        auto lse_dram_window = make_tile_window(
            lse_dram_block_window.get_bottom_tensor_view(),
            lse_dram_block_window.get_window_lengths(),
            lse_dram_block_window.get_window_origin(),
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());

        auto d_dram_window = make_tile_window(
            d_dram_block_window.get_bottom_tensor_view(),
            d_dram_block_window.get_window_lengths(),
            d_dram_block_window.get_window_origin(),
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());

        auto bias_dram_window =
            make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
                             bias_dram_block_window.get_window_lengths(),
                             bias_dram_block_window.get_window_origin(),
                             Policy::template MakeBiasTileDistribution<Problem>());

        auto biast_lds_window =
            make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                             biast_lds_shuffle_window.get_window_lengths(),
                             biast_lds_shuffle_window.get_window_origin(),
                             Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());

        auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
            randval_dram_block_window_tmp, seqlen_q_start);

        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kM0 / kK1;
        constexpr index_t k2_loops = kVHeaddim / kK2;
        constexpr index_t k3_loops = kM0 / kK3;
        constexpr index_t k4_loops = kN0 / kK4;

        // Gemm0 and Gemm2 are software-pipelined with a fixed prologue plus a two-step tail that
        // slices at (kX_loops - 2) * kKX; Gemm1 and Gemm3 have a one-step tail that slices at
        // (kX_loops - 1) * kKX. Fewer K-steps than that would produce negative slice origins
        // and out-of-range global loads, so reject such tile shapes at compile time.
        static_assert(k0_loops >= 2,
                      "ks_kts_vr pipeline requires kQKHeaddim / kK0 >= 2 (Gemm0 prologue + tail)");
        static_assert(k2_loops >= 2,
                      "ks_kts_vr pipeline requires kVHeaddim / kK2 >= 2 (Gemm2 prologue + tail)");
        static_assert(k1_loops >= 1, "ks_kts_vr pipeline requires kM0 / kK1 >= 1 (Gemm1 tail)");
        static_assert(k3_loops >= 1, "ks_kts_vr pipeline requires kM0 / kK3 >= 1 (Gemm3 tail)");

        // NOTE(review): this is a do/while, so the body runs at least once. The
        // num_total_loop <= 0 early exit above is only compiled in when FmhaMask::IsMasking —
        // presumably callers guarantee a non-empty Q range otherwise; confirm.
        do
        {
            auto q_dram_window = make_tile_window(
                q_dram_block_window.get_bottom_tensor_view(),
                q_dram_block_window.get_window_lengths(),
                q_dram_block_window.get_window_origin(),
                Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
                                                                        // load

            auto do_dram_window = make_tile_window(
                do_dram_block_window.get_bottom_tensor_view(),
                do_dram_block_window.get_window_lengths(),
                do_dram_block_window.get_window_origin(),
                Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
                                                                            // window for load

            // STAGE 1, Q@K Gemm0
            auto st_acc = SPTBlockTileType{};

            auto q_block_tile = load_tile(q_dram_window);
            {
                move_tile_window(q_dram_window, {0, kK0});

                clear_tile(st_acc); // Initialize S^T

                store_tile(q_lds_window, q_block_tile);  // LDS write 0
                q_block_tile = load_tile(q_dram_window); // global read 1
            }

            if constexpr(kHasBias)
            {
                __builtin_amdgcn_sched_barrier(
                    0); // prevent from messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            if constexpr(kHasBias)
            {
                __builtin_amdgcn_sched_barrier(
                    0); // prevent from messing up the order of global loads
            }

            if constexpr(k0_loops > 2)
            {
                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                    block_sync_lds();
                    gemm_0(st_acc,
                           q_lds_window,
                           get_slice_tile(k_lds_window,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kN0, (i_k0 + 1) * kK0>{}));
                    block_sync_lds();
                    move_tile_window(q_dram_window, {0, kK0});

                    store_tile(q_lds_window,
                               q_block_tile);                // LDS write i + 1
                    q_block_tile = load_tile(q_dram_window); // global read i + 2
                });
            }

            const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
            {                                                     // tail
                block_sync_lds();
                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 2) * kK0>{},
                                      sequence<kN0, (k0_loops - 1) * kK0>{}));
                block_sync_lds();

                store_tile(q_lds_window, q_block_tile);
                block_sync_lds();

                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kN0, k0_loops * kK0>{}));
            }

            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr(kHasBias)
            {
                // The bias tile is transposed through the shared LDS scratch region.
                block_sync_lds();
                auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(bias_shuffle_tmp, bias_tile);
                store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
                block_sync_lds();
                auto biast_tile = load_tile(biast_lds_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x = raw_scale * x + type_convert<AccDataType>(y);
#else
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                    },
                    st_acc,
                    biast_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else
            {
                // On the exp2 path the scale is folded into the exponent below instead.
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
            }

            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto q_origin      = q_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});
                if(need_perpixel_check)
                {
                    // Out-of-bound elements get -inf so that they become 0 after the exp.
                    set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        return mask.IsOutOfBound(row, col);
                    });
                }
            }

            const auto lse = load_tile(lse_dram_window);

            // With bias or masking an LSE of -inf is replaced by 0 before it is subtracted.
            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr(kHasBias || FmhaMask::IsMasking)
                {
                    return raw_lse == -numeric<LSEDataType>::infinity()
                               ? type_convert<LSEDataType>(0.f)
                               : raw_lse;
                }
                else
                {
                    return raw_lse;
                }
            };

            // P^T = exp(S^T - LSE), recomputed from the stored log-sum-exp.
            auto pt                 = SPTBlockTileType{};
            constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
            sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
                sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(kHasBias)
                    {
                        pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
                    }
                    else
                    {
                        pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                    }
#else
                    pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
                });
            });

            auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
                Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(dot_shuffle_tmp, dot_prefetch);
                store_tile(dot_lds_window,
                           dot_shuffle_tmp); // store the prefetch
            }
            move_tile_window(dot_dram_window, {0, kK1});

            if constexpr(kHasDropout)
            {
                // NOTE(review): the code below treats negative pt entries as "dropped"
                // (see pt_gemm and STAGE 5) — presumably Run() marks them by sign; confirm.
                dropout.Run<decltype(gemm_0), RandValOutputDataType>(
                    seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
            }

            // STAGE 3, P^T@OGrad^T Gemm1
            const auto pt_gemm = [&]() {
                if constexpr(kHasDropout)
                {
                    // Non-positive entries contribute 0 to dV.
                    return tile_elementwise_in(
                        [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
                        pt);
                }
                else
                {
                    return cast_tile<GemmDataType>(pt);
                }
            }();

            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    const auto dot = load_tile(dot_dram_window); // load next OGrad^T
                    block_sync_lds();
                    gemm_1(dv_acc,
                           get_slice_tile(pt_gemm,
                                          sequence<i_k1 * kK1, 0>{},
                                          sequence<(i_k1 + 1) * kK1, kN0>{}),
                           dot_lds_window);
                    block_sync_lds();
                    shuffle_tile(dot_shuffle_tmp, dot);
                    store_tile(dot_lds_window,
                               dot_shuffle_tmp); // store the prefetch

                    move_tile_window(dot_dram_window, {0, kK1});
                });
            }
            auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
            // tail
            {
                block_sync_lds();
                gemm_1(dv_acc,
                       get_slice_tile(
                           pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
                       dot_lds_window);
                block_sync_lds();
            }

            // STAGE 4, OGrad@V Gemm2
            auto dpt_acc = SPGradTBlockTileType{};

            {
                move_tile_window(do_dram_window, {0, kK2});

                clear_tile(dpt_acc); // Initialize PGrad^T

                store_tile(do_lds_window, do_block_tile);  // LDS write 0
                do_block_tile = load_tile(do_dram_window); // global read 1
            }

            if constexpr(k2_loops > 2)
            {
                static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
                    block_sync_lds();
                    gemm_2(dpt_acc,
                           do_lds_window,
                           get_slice_tile(
                               v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
                    block_sync_lds();
                    move_tile_window(do_dram_window, {0, kK2});

                    store_tile(do_lds_window,
                               do_block_tile);                 // LDS write i + 1
                    do_block_tile = load_tile(do_dram_window); // global read i + 2
                });
            }

            const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
            {                                                   // tail
                block_sync_lds();
                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 2) * kK2>{},
                                      sequence<kN0, (k2_loops - 1) * kK2>{}));
                block_sync_lds();

                store_tile(do_lds_window, do_block_tile);
                block_sync_lds();

                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(v,
                                      sequence<0, (k2_loops - 1) * kK2>{},
                                      sequence<kN0, k2_loops * kK2>{}));
            }

            // STAGE 5, P^T(PGrad^T - D)
            const auto d = load_tile(d_dram_window);

            auto dst                 = SPGradTBlockTileType{};
            constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
            sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // With dropout, a negative pt entry takes the second branch: pt * d.
                    bool undrop_flag       = pt[i_j_idx] >= 0;
                    dst(i_j_idx) =
                        pt[i_j_idx] *
                        (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
                });
            });

            if constexpr(kHasBiasGrad)
            {
                const auto dbiast = [&]() {
                    if constexpr(kHasDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
                                return type_convert<BiasGradDataType>(x * rp_undrop);
                            },
                            dst);
                    }
                    else
                    {
                        return cast_tile<BiasGradDataType>(dst);
                    }
                }();
                // Transpose dBias^T through the shared LDS scratch region before writing it out.
                store_tile(biast_lds_shuffle_window, dbiast);
                block_sync_lds();
                auto dbiast_tile        = load_tile(dbiast_lds_shuffle_window);
                auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
                move_tile_window(dbias_dram_block_window, {kM0, 0});
            }

            // STAGE 6, SGrad^T@Q^T Gemm3
            auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
                Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
            block_sync_lds();
            {
                shuffle_tile(qt_shuffle_tmp, qt_prefetch);
                store_tile(qt_lds_window,
                           qt_shuffle_tmp); // store the prefetch
            }
            move_tile_window(qt_dram_window, {0, kK3});

            const auto dst_gemm = cast_tile<GemmDataType>(dst);

            if constexpr(k3_loops > 1)
            {
                static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
                    const auto qt = load_tile(qt_dram_window); // load next Q^T
                    block_sync_lds();
                    gemm_3(dk_acc,
                           get_slice_tile(dst_gemm,
                                          sequence<i_k3 * kK3, 0>{},
                                          sequence<(i_k3 + 1) * kK3, kN0>{}),
                           qt_lds_window);
                    block_sync_lds();
                    shuffle_tile(qt_shuffle_tmp, qt);
                    store_tile(qt_lds_window,
                               qt_shuffle_tmp); // store the prefetch

                    move_tile_window(qt_dram_window, {0, kK3});
                });
            }
            // tail
            {
                block_sync_lds();
                gemm_3(dk_acc,
                       get_slice_tile(
                           dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
                       qt_lds_window);
                block_sync_lds();
            }

            // STAGE 7, SGrad@K^T Gemm4
            store_tile(ds_lds_window, dst_gemm);

            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc); // Initialize QGrad

            block_sync_lds();

            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                gemm_4(dq_acc,
                       get_slice_tile(ds_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kM0, (i_k4 + 1) * kK4>{}),
                       get_slice_tile(kt_lds_window,
                                      sequence<0, i_k4 * kK4>{},
                                      sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
            });

            // QGrad Scale
            if constexpr(kHasDropout)
            {
                tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                       dq_acc);
            }
            else
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }
            const auto dq = cast_tile<QGradDataType>(dq_acc);
            update_tile(dq_dram_block_window, dq);

            // move tile windows
            move_tile_window(q_dram_block_window, {kM0, 0});
            move_tile_window(dq_dram_block_window, {kM0, 0});
            move_tile_window(do_dram_block_window, {kM0, 0});
            move_tile_window(lse_dram_window, {kM0});
            move_tile_window(d_dram_window, {kM0});
        } while(++i_total_loops < num_total_loop);

        // KGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dk_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }
        // VGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }

        return ck_tile::make_tuple(dk_acc, dv_acc);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the "ks_kts_vr" backward pipeline: V is kept in registers while K and K^T
// are kept in LDS. Accordingly KLoadOnce_, KTLoadOnce_ and VLoadOnce_ are true and every other
// *LoadOnce_ flag is false.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -0,0 +1,794 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// Block-level FMHA backward pipeline ("ks_vr" variant) producing dQ, dK and dV.
//
// One call of operator() handles a single K/V tile (kN0 rows): the K tile is stored once into
// LDS and the V tile is loaded once into registers. A loop then walks the Q-sequence range
// reported by the mask in steps of kM0 rows; Q, Q^T, OGrad and OGrad^T are re-read from DRAM on
// every iteration. dK and dV are accumulated in registers and returned, while dQ (and, when
// kHasBiasGrad is set, dBias) is written to DRAM from inside the loop.
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
// Element types, all taken from Problem.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
// Launch and tile configuration, taken from Problem / BlockFmhaShape.
// kM0 is the row count of the Q-side tiles and kN0 the row count of the K/V tiles (see the
// static_asserts in operator()); kK0..kK4 are the per-step chunk sizes of gemm_0..gemm_4
// (see k0_loops..k4_loops in operator()).
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
// Compile-time description of which operands this variant loads only once per call:
// K and V. These flags are not read inside this struct.
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
// Mode / padding / feature switches, taken from Problem.
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this. When the relevant dimension is padded the alignment falls back to 1.
// NOTE(review): kAlignmentQGrad falls back to 2 instead of 1 - presumably a requirement of how
// dQ is written (update_tile in operator()); confirm.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
// Short identifier of this pipeline variant.
static constexpr const char* name = "ks_vr";
// Returns the shared-memory (LDS) size required by this pipeline, as computed by Policy.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// Runs the backward pipeline for one K/V tile.
//
// The *_tmp windows supply the tensor view and window lengths. The origins of the K, V windows
// (and the column origin of the bias / dbias windows) are used as given; every Q-side window
// is re-created at the first Q row selected by the mask (seqlen_q_start).
//
//   q_dram_block_window_tmp       Q, kM0 rows per window.
//   qt_dram_block_window_tmp      Q^T, kQKHeaddim rows per window.
//   k_dram_block_window_tmp       K, kN0 rows per window.
//   (kt window)                   unused by this variant.
//   v_dram_block_window_tmp       V, kN0 rows per window.
//   bias_dram_block_window_tmp    bias, kM0 x kN0.
//   randval_dram_block_window_tmp forwarded to dropout.MakeRandvalDramWindow().
//   do_dram_block_window_tmp      OGrad, kM0 rows per window.
//   dot_dram_block_window_tmp     OGrad^T, kVHeaddim rows per window.
//   lse_dram_block_window_tmp     LSE, kM0 elements per window.
//   d_dram_block_window_tmp       D, kM0 elements per window.
//   dq_dram_block_window_tmp      dQ output, kM0 rows per window (written with update_tile).
//   dbias_dram_block_window_tmp   dBias output, kM0 x kN0 (written only if kHasBiasGrad).
//   mask                          supplies the Q range, edge-tile test and out-of-bound test.
//   raw_scale                     multiplies S (when fast exp2 is off) and dQ / dK (no dropout).
//   scale                         only with CK_TILE_FMHA_FWD_FAST_EXP2; multiplies S before exp2
//                                 (presumably raw_scale * log2(e) - confirm against the caller).
//   rp_undrop                     multiplies dBias and dV when kHasDropout
//                                 (presumably 1 / keep-probability - confirm).
//   scale_rp_undrop               multiplies dQ and dK when kHasDropout
//                                 (presumably raw_scale * rp_undrop - confirm).
//   smem_ptr                      base of the LDS buffer; K is placed at the base, every other
//                                 LDS tile at base + Policy::GetSmemSizeK<Problem>().
//   dropout                       dropout helper; only Run() is guarded by kHasDropout.
//
// Returns make_tuple(dk_acc, dv_acc). On the masked early-exit path both are all-zero.
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
// Compile-time check: every window carries the element type this pipeline expects.
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
// Compile-time check: window extents match the tile geometry (kM0 / kN0 / head dims).
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// LDS layout: K lives at smem_ptr. The Q, Q^T, OGrad, OGrad^T, SGrad and BiasT views below
// are all based at the same address (smem_ptr + GetSmemSizeK), i.e. they alias one another
// and are used in different stages of the loop body.
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS: a second view (different descriptor) over the same memory as the K tile
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
// Window without a distribution: used as the store target for bias / dbias tiles.
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// Same LDS tile with the shuffled-bias distribution: used to read dbias back (STAGE 5).
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// The bias and dbias tiles share one LDS buffer typed as BiasDataType, hence this check.
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
// gemm_0: S^T  (Q with K)             -> st_acc
// gemm_1: dV  += P^T with OGrad^T     -> dv_acc
// gemm_2: dP^T (OGrad with V)         -> dpt_acc
// gemm_3: dK  += dS^T with Q^T        -> dk_acc
// gemm_4: dQ   (dS with K^T)          -> dq_acc
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
// V is read from DRAM exactly once, with the distribution gemm_2 expects for its operand.
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
// Accumulator tile types, derived from the output tile of the GEMM that produces them.
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
// Scheduling barrier: keeps the compiler from reordering instructions across this point.
__builtin_amdgcn_sched_barrier(0);
// Ask the mask which Q rows [seqlen_q_start, seqlen_q_end) this K tile has to visit;
// the loop below processes that range in steps of kM0 rows.
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
// Without masking there is no early exit: the do/while below always runs at least once.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleared, return it
// Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS
// Re-create every Q-side window so that it starts at the first Q row to process.
// Row-major windows (Q, OGrad, dQ) start at {seqlen_q_start, 0}; the transposed ones
// (Q^T, OGrad^T) at {0, seqlen_q_start}; LSE and D are one-dimensional.
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto qt_dram_block_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dot_dram_block_window =
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
// bias / dbias keep the caller's column origin and only replace the row origin.
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
// Distributed load windows that live across loop iterations. They are advanced inside
// the loop (Q^T by kK3 and OGrad^T by kK1 per chunk, LSE / D / bias by kM0 per iteration).
auto qt_dram_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
// Read-side window over the bias LDS tile, distributed to match gemm_0's output.
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
// Loop counter over Q tiles, and the number of chunks each GEMM is split into.
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
// Main loop: one iteration per kM0-row Q tile.
do
{
// Q and OGrad load windows are rebuilt every iteration from the block windows, which are
// advanced by kM0 rows at the bottom of the loop.
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
// Q is streamed in kK0-wide chunks: while gemm_0 consumes the chunk in LDS, the next
// chunk has already been fetched from DRAM into q_block_tile.
// NOTE(review): the prologue and tail below always issue two gemm_0 steps, so this code
// assumes k0_loops >= 2; nothing enforces that at compile time.
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
{
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
}
if constexpr(kHasBias)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
// NOTE(review): this load is issued regardless of kHasBias; only the surrounding
// scheduling barriers are conditional. Presumably the window is a dummy when there is no
// bias - confirm against the caller.
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
// Issued before the gemm_0 tail; the tile is only consumed after STAGE 2.
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(kHasBias)
{
// Route the bias tile through LDS to re-read it with the distribution of st_acc.
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
// With fast exp2 the bias is multiplied by log2(e) because exp2 is used below.
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else
{
// Without bias and with fast exp2 enabled, st_acc is left unscaled here; the scale is
// applied inside the exp2() argument further down.
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
// Only tiles the mask reports as edge tiles need the per-element test; elements
// reported out of bound are set to -infinity.
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
// Maps an LSE of -infinity to 0 when bias or masking is enabled; otherwise passes the
// value through. Presumably this avoids evaluating (-inf) - (-inf) below - confirm.
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
// Compute P^T element-wise from S^T and the per-row LSE: exp(S - lse), or the equivalent
// exp2 form when CK_TILE_FMHA_FWD_FAST_EXP2 is enabled.
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
// Write the prefetched OGrad^T chunk to LDS (through a shuffle) for gemm_1.
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
// Applies dropout to pt in place. The code below treats pt > 0 as "kept" and a negative
// pt as "dropped", so Run() presumably marks dropped entries by sign - confirm.
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
// GEMM operand: pt converted to GemmDataType, with non-positive entries replaced by 0
// when dropout is enabled.
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
// OGrad^T is streamed in kK1-wide chunks: load the next chunk, run gemm_1 on the chunk
// already in LDS, then overwrite LDS with the newly loaded one.
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2
// Same streaming scheme as STAGE 1, with OGrad in kK2-wide chunks through LDS and V
// sliced from the persistent register tile.
// NOTE(review): as in STAGE 1, the prologue and tail assume k2_loops >= 2.
auto dpt_acc = SPGradTBlockTileType{};
{
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T
store_tile(do_lds_window, do_block_tile); // LDS write 0
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2)
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
// Issued before the gemm_2 tail; the tile is only consumed in STAGE 6.
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
// STAGE 5, P^T(PGrad^T - D)
// dst = pt * (dpt_acc - d) for every element, except that with dropout enabled an
// element whose pt is negative uses pt * d instead.
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
// dBias is dst (times rp_undrop with dropout) converted to BiasGradDataType, routed
// through the bias LDS tile, shuffled to the bias DRAM distribution and stored.
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
// NOTE(review): the bias LDS tile shares its base address with do_lds, which the last
// gemm_2 above reads, and there is no block_sync_lds() between that gemm_2 and this
// store. Confirm that no barrier is needed here.
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
// Same chunked scheme as STAGE 3, with Q^T in kK3-wide chunks accumulating into dk_acc.
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
}
move_tile_window(qt_dram_window, {0, kK3});
const auto dst_gemm = cast_tile<GemmDataType>(dst);
if constexpr(k3_loops > 1)
{
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
}
// tail
{
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
}
// STAGE 7, SGrad@K^T Gemm4
// dS is written to LDS once; gemm_4 then consumes it in kK4-wide slices together with
// the matching slices of the K^T view over the persistent K tile.
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
// This iteration's dQ contribution goes to DRAM through update_tile().
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
// Q / dQ / OGrad advance their block windows (the load windows are rebuilt at the top of
// the loop); LSE and D advance their load windows directly.
// NOTE(review): the next iteration stores Q into LDS at the address ds_lds also uses,
// with no block_sync_lds() after the gemm_4 reads above. Confirm that is safe.
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
// dV is only rescaled when dropout is enabled.
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the "ks_vr" FMHA backward pipeline: V is located in registers and K is
// located in LDS. Only the K and V load-once flags are set; Q, Q^T, K^T, OGrad and OGrad^T
// are not load-once.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -0,0 +1,665 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = true;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = true;
static constexpr bool kOGradTLoadOnce = false;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "qs_ks_vr_dos";
// Shared-memory (LDS) footprint of this pipeline; the value is computed by Policy for Problem.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_bytes = Policy::template GetSmemSize<Problem>();
    return smem_bytes;
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
// QT tile in LDS
auto qt_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
// OGradT tile in LDS
auto dot_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: dk_acc and dv_acc have only been cleared at this point, so return them as-is
// Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write
if constexpr(kHasBias)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
block_sync_lds();
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(kHasBias)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds();
store_tile(do_lds_window, do_block_tile); // store the prefetch
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
get_slice_tile(dot_lds_window,
sequence<0, i_k1 * kK1>{},
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
block_sync_lds();
});
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc); // Initialize PGrad^T
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
get_slice_tile(do_lds_window,
sequence<0, i_k2 * kK2>{},
sequence<kM0, (i_k2 + 1) * kK2>{}),
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
});
// STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst);
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
get_slice_tile(qt_lds_window,
sequence<0, i_k3 * kK3>{},
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
block_sync_lds();
});
// STAGE 7, SGrad@K^T Gemm4
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the pipeline that keeps V in registers and Q, K and
// OGrad (dO) in LDS.
// NOTE(review): the precise effect of each *LoadOnce_ flag is defined by
// BlockFmhaBwdPipelineDefaultPolicy, which is declared in the included header.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy<true,   // QLoadOnce_
                                      false,  // QTLoadOnce_
                                      true,   // KLoadOnce_
                                      false,  // KTLoadOnce_
                                      true,   // VLoadOnce_
                                      true,   // OGradLoadOnce_
                                      false>; // OGradTLoadOnce_
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Tag enumeration of the backward FMHA pipelines; it is used for codegen
// pattern matching, so the enumerator names must stay stable.
// Values are sequential, starting at 0.
enum class BlockFmhaBwdPipelineEnum
{
    KSKTSVR      = 0,
    QSKSVROGradS = 1,
    KSVR         = 2,
};
} // namespace ck_tile

View File

@@ -0,0 +1,91 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Compile-time "problem" descriptor for a backward FMHA pipeline.
//
// It carries no data: it only re-exports its template arguments (with cv/ref
// qualifiers stripped) and a set of constants, so that pipelines and policies
// can query everything through a single `Problem` type.
//
// Template parameters:
//   *DataType_       element types of the individual tensors and gradients
//   BlockFmhaShape_  tile-shape type; must expose `NumWarps`
//   kIsGroupMode_    forwarded unchanged as `kIsGroupMode`
//   FmhaMask_        mask type
//   Traits_          must expose kPadSeqLenQ, kPadSeqLenK, kPadHeadDimQ,
//                    kPadHeadDimV, kHasBias, kHasBiasGrad, kHasDropout and
//                    kBlockPerCu
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          typename GemmDataType_,
          typename LSEDataType_,
          typename AccDataType_,
          typename DDataType_,
          typename BiasDataType_,
          typename RandValOutputDataType_,
          typename ODataType_,
          typename OGradDataType_,
          typename QGradDataType_,
          typename KGradDataType_,
          typename VGradDataType_,
          typename BiasGradDataType_,
          typename BlockFmhaShape_,
          bool kIsGroupMode_,
          typename FmhaMask_,
          typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
    // ---- forward-side and intermediate element types ----
    using QDataType             = remove_cvref_t<QDataType_>;
    using KDataType             = remove_cvref_t<KDataType_>;
    using VDataType             = remove_cvref_t<VDataType_>;
    using GemmDataType          = remove_cvref_t<GemmDataType_>;
    using LSEDataType           = remove_cvref_t<LSEDataType_>;
    using AccDataType           = remove_cvref_t<AccDataType_>;
    using DDataType             = remove_cvref_t<DDataType_>;
    using BiasDataType          = remove_cvref_t<BiasDataType_>;
    using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
    using ODataType             = remove_cvref_t<ODataType_>;

    // ---- gradient element types ----
    using OGradDataType    = remove_cvref_t<OGradDataType_>;
    using QGradDataType    = remove_cvref_t<QGradDataType_>;
    using KGradDataType    = remove_cvref_t<KGradDataType_>;
    using VGradDataType    = remove_cvref_t<VGradDataType_>;
    using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;

    // ---- shape, mask and traits ----
    using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
    using FmhaMask       = remove_cvref_t<FmhaMask_>;
    using Traits         = remove_cvref_t<Traits_>;

    // Block size is derived from the shape: number of warps times the warp size.
    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();

    static constexpr bool kIsGroupMode = kIsGroupMode_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ     = Traits::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK     = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ    = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV    = Traits::kPadHeadDimV;
    static constexpr bool kHasBias        = Traits::kHasBias;
    static constexpr bool kHasBiasGrad    = Traits::kHasBiasGrad;
    static constexpr bool kHasDropout     = Traits::kHasDropout;
    static constexpr index_t kBlockPerCu  = Traits::kBlockPerCu;
};
// Compile-time "problem" descriptor for the backward OGrad-dot-O pipeline.
//
// Unlike BlockFmhaBwdPipelineProblem, the block size is supplied directly as a
// template argument rather than being derived from a tile shape, so it is
// validated here.
//
// Template parameters:
//   ODataType_ / OGradDataType_ / DDataType_  element types (cv/ref stripped)
//   kBlockSize_    must be positive and a multiple of get_warp_size()
//   kVHeaddim_     forwarded unchanged as `kVHeaddim`
//   kIsGroupMode_  forwarded unchanged as `kIsGroupMode`
//   Traits_        must expose kPadSeqLenQ, kPadHeadDimV and kBlockPerCu
template <typename ODataType_,
          typename OGradDataType_,
          typename DDataType_,
          index_t kBlockSize_,
          index_t kVHeaddim_,
          bool kIsGroupMode_,
          typename Traits_>
struct BlockFmhaBwdOGradDotOPipelineProblem
{
    using ODataType     = remove_cvref_t<ODataType_>;
    using OGradDataType = remove_cvref_t<OGradDataType_>;
    using DDataType     = remove_cvref_t<DDataType_>;
    using Traits        = remove_cvref_t<Traits_>;

    // The condition rejects zero/negative sizes as well as sizes that are not a
    // whole number of warps; the message states both so that a failure caused by
    // kBlockSize_ <= 0 is not reported as a divisibility problem.
    static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
                  "kBlockSize should be positive and divisible by get_warp_size()");

    static constexpr index_t kBlockSize = kBlockSize_;
    static constexpr index_t kVHeaddim  = kVHeaddim_;
    static constexpr bool kIsGroupMode  = kIsGroupMode_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV   = Traits::kPadHeadDimV;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -703,7 +703,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
template <typename Problem>
__host__ __device__ static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
if constexpr(AsyncCopyK)
{
@@ -716,7 +716,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
template <typename Problem>
__host__ __device__ static constexpr ck_tile::index_t GetSmemSizeDropout()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
{
if constexpr(Problem::kHasDropout)
{

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {
@@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
@@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
@@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
@@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// Construct C-Block-HostTensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};

View File

@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBRegCRegV1DefaultPolicy>
struct BlockGemmASmemBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockTensorTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindowTmp& a_block_window_tmp,
const BBlockTensorTmp& b_block_tensor_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensorTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
// construct B-block-tensor from B-block-tensor-tmp
// FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent
// distribution
auto b_block_tensor =
make_static_distributed_tensor<typename BBlockTensorTmp::DataType>(b_block_dstr);
b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer();
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Creates the block-level C tile as a static distributed tensor of CDataType.
//
// The tile distribution is assembled in two layers:
//   - an outer encoding laid out as (MIterPerWarp, MWarp) x (NIterPerWarp, NWarp), where the
//     warp counts and the warp GEMM type come from Policy::GetWarpGemmMWarpNWarp<Problem>();
//   - the warp GEMM's own C encoding (CWarpDstrEncoding), embedded inside the outer one.
//
// No element values are written here; only the tensor object is created and returned.
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
    // Policy hands back a tuple: <warp gemm instance, #warps along M, #warps along N>.
    constexpr auto warp_cfg = Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WGemm = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kWarpsAlongM = warp_cfg.template at<1>();
    constexpr index_t kWarpsAlongN = warp_cfg.template at<2>();

    // Number of warp-GEMM steps each warp performs to cover the block along M and N.
    constexpr index_t kIterAlongM = BlockGemmShape::kM / (kWarpsAlongM * WGemm::kM);
    constexpr index_t kIterAlongN = BlockGemmShape::kN / (kWarpsAlongN * WGemm::kN);

    using COuterDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<kIterAlongM, kWarpsAlongM>, sequence<kIterAlongN, kWarpsAlongN>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>;

    constexpr auto c_dstr_encoding = detail::make_embed_tile_distribution_encoding(
        COuterDstrEncoding{}, typename WGemm::CWarpDstrEncoding{});

    constexpr auto c_dstr = make_static_tile_distribution(c_dstr_encoding);

    auto c_tile = make_static_distributed_tensor<CDataType>(c_dstr);
    return c_tile;
}
// C = A * B
//
// Convenience overload: creates a C block tile via MakeCBlockTile(), forwards it together
// with the A block window and the B block tensor to the three-argument overload, and
// returns the resulting tile by value.
//
// NOTE(review): the tile returned by MakeCBlockTile() is not explicitly cleared here before
// it is handed to the three-argument overload, which reads the existing C values before the
// warp GEMM. Confirm that the tile starts out zeroed (or that callers do not rely on it).
template <typename ABlockWindowTmp, typename BBlockTensorTmp>
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
                               const BBlockTensorTmp& b_block_tensor_tmp) const
{
    auto c_tile = MakeCBlockTile();

    (*this)(c_tile, a_block_window_tmp, b_block_tensor_tmp);

    return c_tile;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Custom policy for BlockGemmASmemBRegCRegV1.
//
// Unlike a default policy, the warp GEMM type and the warp layout are fixed by the template
// arguments rather than derived from the Problem:
//   ADataT / BDataT / CDataT : element types, exposed (cv-ref stripped) as AType/BType/CType
//   WarpLayout               : compile-time container indexed with number<0..2>, giving the
//                              warp counts along M, N and K
//   WarpGemmOp               : warp-level GEMM type, exposed as WarpGemm
template <typename ADataT,
          typename BDataT,
          typename CDataT,
          typename WarpLayout,
          typename WarpGemmOp>
struct BlockGemmASmemBRegCRegV1CustomPolicy
{
    // Element types
    using AType = remove_cvref_t<ADataT>;
    using BType = remove_cvref_t<BDataT>;
    using CType = remove_cvref_t<CDataT>;

    // Warp-level GEMM and the per-block warp layout
    using WarpGemm   = remove_cvref_t<WarpGemmOp>;
    using BlockWarps = remove_cvref_t<WarpLayout>;

    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    // kKWarps is extracted for completeness; GetWarpGemmMWarpNWarp() below does not use it.
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    // Returns the tuple <WarpGemm instance, kMWarps, kNWarps>.
    // The Problem parameter is accepted for interface parity with other policies and is not
    // consulted here.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,56 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmASmemBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmASmemBRegCRegV1DefaultPolicy
{
    // Selects the warp GEMM and warp layout from the data types declared by Problem.
    //
    // Returns a tuple <warp gemm instance, #warps along M, #warps along N>:
    //   A = half_t, B = half_t, C = float ->
    //       WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution, 4, 1
    //   A = bf16_t, B = bf16_t, C = float ->
    //       WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution, 4, 1
    //
    // Any other combination is rejected at compile time. Without the final branch the
    // deduced return type would silently become void and the failure would only surface
    // later, at the call site, with an unrelated-looking message.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
#if 0
            // Disabled: block-shape dependent selection (both branches currently identical).
            constexpr index_t kBlockSize = Problem::kBlockSize;
            constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
            constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
            constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
            static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
            constexpr index_t NumWarp = kBlockSize / get_warp_size();
            // FIXME
            if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
                         kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
            }
            else
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
            }
#else
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else
        {
            // The condition depends on Problem so it is only evaluated when this branch is
            // actually instantiated, i.e. for an unsupported type combination.
            static_assert(!std::is_same_v<Problem, Problem>,
                          "BlockGemmASmemBRegCRegV1DefaultPolicy: unsupported A/B/C data type "
                          "combination (expected fp16/fp16/fp32 or bf16/bf16/fp32)");
        }
    }
};
} // namespace ck_tile

View File

@@ -526,9 +526,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec)
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
@@ -541,14 +541,14 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
auto c_vec = Impl{}(
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec)
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});