mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
FA bwd
This commit is contained in:
@@ -1,17 +1,29 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt
|
||||
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
# as the current CMake list file, otherwise CMake will not figure out the dependency properly
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS)
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
|
||||
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
# Log the backward example target being registered; use the variable that is
# actually set for this target (EXAMPLE_NAME is not defined in this list file).
message("adding tile_example ${EXAMPLE_FMHA_BWD}")
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
# NOTE: this is dangerous since it will change the whole kernel to flush denormals
|
||||
# WIP with compiler team for an exp2 intrinsic..., then remove this
|
||||
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
@@ -29,16 +49,21 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
endif()
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let the sources compile without explicitly declaring function specializations
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
|
||||
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
@@ -9,35 +14,28 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor/tensor_view.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
#include "common/arg_parser.hpp"
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "reference/reference_batched_elementwise.hpp"
|
||||
#include "reference/reference_batched_gemm.hpp"
|
||||
#include "reference/reference_batched_masking.hpp"
|
||||
#include "reference/reference_batched_softmax.hpp"
|
||||
#include "reference/reference_batched_dropout.hpp"
|
||||
#include "utils.hpp"
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ArgParser arg_parser;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
@@ -69,11 +67,11 @@ auto create_args(int argc, char* argv[])
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, possitive is swa\n"
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, possitive is swa\n"
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)\n")
|
||||
"now)")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
|
||||
.insert("seed",
|
||||
@@ -96,18 +94,18 @@ auto get_elimit(int /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck::make_tuple(rtol, atol);
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ArgParser& arg_parser)
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck::index_t batch = arg_parser.get_int("b");
|
||||
ck::index_t nhead = arg_parser.get_int("h");
|
||||
ck::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
if(nhead_k == 0)
|
||||
nhead_k = nhead;
|
||||
|
||||
@@ -117,12 +115,12 @@ bool run(const ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
ck::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k == 0)
|
||||
seqlen_k = seqlen_q;
|
||||
ck::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v == 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
@@ -136,7 +134,7 @@ bool run(const ArgParser& arg_parser)
|
||||
|
||||
float scale = arg_parser.get_float("scale");
|
||||
if(scale == .0f)
|
||||
scale = 1.0 / ck::math::sqrt(static_cast<float>(hdim_q));
|
||||
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool use_dbias = arg_parser.get_bool("dbias");
|
||||
@@ -178,7 +176,7 @@ bool run(const ArgParser& arg_parser)
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
|
||||
StreamConfig stream_config{
|
||||
ck_tile::stream_config stream_config{
|
||||
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
|
||||
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
@@ -209,7 +207,7 @@ bool run(const ArgParser& arg_parser)
|
||||
auto max_seqlen_k =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
{
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
@@ -224,11 +222,10 @@ bool run(const ArgParser& arg_parser)
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
flop += nhead *
|
||||
(3_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
|
||||
2_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
|
||||
flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
|
||||
real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
|
||||
static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
|
||||
real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
@@ -243,85 +240,97 @@ bool run(const ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
auto get_lengths = [&](bool permute,
|
||||
ck::index_t b /*batch*/,
|
||||
ck::index_t h /*nhead*/,
|
||||
ck::index_t s /*seqlen*/,
|
||||
ck::index_t d /*hdim*/) {
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/) {
|
||||
if(permute)
|
||||
return std::array<ck::index_t, 4>{b, h, s, d};
|
||||
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck::index_t, 4>{b, s, h, d};
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
};
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck::index_t shape_seqlen_q =
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck::index_t shape_seqlen_k =
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
|
||||
Tensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
Tensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
Tensor<VDataType> v_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
Tensor<BiasDataType> bias_host(
|
||||
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
Tensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
Tensor<LSEDataType> lse_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
Tensor<DDataType> d_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
Tensor<RandValOutputDataType> randval_host(
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
use_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
ck_tile::HostTensor<DDataType> d_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1});
|
||||
Tensor<QGradDataType> dq_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
Tensor<KGradDataType> dk_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
|
||||
Tensor<VGradDataType> dv_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
|
||||
Tensor<OGradDataType> do_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
Tensor<BiasGradDataType> dbias_host(
|
||||
use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<QGradDataType> dq_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KGradDataType> dk_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<VGradDataType> dv_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
|
||||
ck_tile::HostTensor<OGradDataType> do_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host(
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
ck::utils::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck::utils::FillTrigValue<QDataType>{}(q_host);
|
||||
ck::utils::FillTrigValue<KDataType>{}(k_host);
|
||||
ck::utils::FillTrigValue<VDataType>{}(v_host);
|
||||
ck::utils::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck::utils::FillTrigValue<OGradDataType>{}(do_host);
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
|
||||
}
|
||||
|
||||
DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem d_buf(d_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem randval_buf(randval_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dq_buf(dq_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dk_buf(dk_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dv_buf(dv_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem do_buf(do_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dbias_buf(dbias_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -363,40 +372,39 @@ bool run(const ArgParser& arg_parser)
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
// setup stride_* arguments
|
||||
const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
|
||||
const ck::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_randval = (max_seqlen_k);
|
||||
const ck::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_dbias = (i_perm ? shape_seqlen_k : nhead * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
|
||||
const ck_tile::index_t stride_bias = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_lsed = max_seqlen_q;
|
||||
const ck::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * shape_seqlen_k : shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_bias = 0;
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
// setup batch_stride_* arguments
|
||||
const ck::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
|
||||
const ck::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck::index_t batch_stride_lsed = (nhead * max_seqlen_q);
|
||||
const ck::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck::index_t batch_stride_dbias = (nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_bias = 0;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
@@ -456,7 +464,7 @@ bool run(const ArgParser& arg_parser)
|
||||
batch_stride_dbias,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck::index_t>(mask.type),
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
s_randval,
|
||||
@@ -486,42 +494,43 @@ bool run(const ArgParser& arg_parser)
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::vector<Tensor<QDataType>> q_host_refs;
|
||||
std::vector<Tensor<KDataType>> k_host_refs;
|
||||
std::vector<Tensor<VDataType>> v_host_refs;
|
||||
std::vector<Tensor<ODataType>> o_host_refs;
|
||||
std::vector<Tensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<Tensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<Tensor<GemmDataType>> p_lp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<QDataType>> q_host_refs;
|
||||
std::vector<ck_tile::HostTensor<KDataType>> k_host_refs;
|
||||
std::vector<ck_tile::HostTensor<VDataType>> v_host_refs;
|
||||
std::vector<ck_tile::HostTensor<ODataType>> o_host_refs;
|
||||
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
|
||||
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
Tensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
|
||||
Tensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
|
||||
Tensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
|
||||
Tensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
|
||||
Tensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
|
||||
Tensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
|
||||
Tensor<AccDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
|
||||
Tensor<AccDataType> p_hp_host_ref(
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
|
||||
ck_tile::HostTensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> p_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
Tensor<AccDataType> p_dropped_hp_host_ref(
|
||||
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
Tensor<GemmDataType> p_lp_host_ref(
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
|
||||
ck::index_t nr = nhead / nhead_k;
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
@@ -539,62 +548,68 @@ bool run(const ArgParser& arg_parser)
|
||||
|
||||
// reference
|
||||
// S = scale * Q * K^T
|
||||
reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, ck::identity{}, ck::identity{}, [&](AccDataType x) {
|
||||
return scale * x;
|
||||
}); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
if(use_bias)
|
||||
{
|
||||
// clang-format off
|
||||
Tensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
ck_tile::
|
||||
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
reference_batched_masking<AccDataType>(s_host_ref,
|
||||
FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
reference_batched_masking<AccDataType>(
|
||||
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
reference_batched_masking<AccDataType>(
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
reference_batched_masking<AccDataType>(
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, lse_host_ref);
|
||||
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
@@ -603,21 +618,21 @@ bool run(const ArgParser& arg_parser)
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
reference_batched_dropout(
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
|
||||
// clang-format off
|
||||
@@ -652,28 +667,28 @@ bool run(const ArgParser& arg_parser)
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
Tensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
|
||||
Tensor<AccDataType> ds_hp_host_ref(
|
||||
ck_tile::HostTensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
|
||||
ck_tile::HostTensor<AccDataType> ds_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
|
||||
Tensor<GemmDataType> ds_lp_host_ref(
|
||||
ck_tile::HostTensor<GemmDataType> ds_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
|
||||
Tensor<AccDataType> dp_hp_host_ref(
|
||||
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
|
||||
Tensor<BiasGradDataType> dbias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
Tensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
Tensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
Tensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
|
||||
// clang-format off
|
||||
if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
@@ -683,12 +698,12 @@ bool run(const ArgParser& arg_parser)
|
||||
// dP = dO@V x Z w/ dropout
|
||||
// dP = dO@V w/o dropout
|
||||
auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
|
||||
reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
|
||||
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
|
||||
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
reference_batched_dropout(
|
||||
ck_tile::reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
@@ -699,56 +714,59 @@ bool run(const ArgParser& arg_parser)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck::type_convert<AccDataType>(p_hp_host_refs[wb](idx_gmn) *
|
||||
(dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck::type_convert<BiasGradDataType>(self(idx));
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
|
||||
auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
ck_tile::reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
|
||||
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
|
||||
ds_lp_host_ref,
|
||||
k_t_host_ref,
|
||||
dq_host_ref,
|
||||
ck::identity{},
|
||||
ck::identity{},
|
||||
[&scale](const AccDataType& x) { return scale * x; }); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
|
||||
reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
|
||||
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
|
||||
ds_t_lp_host_ref,
|
||||
q_t_host_ref,
|
||||
dk_host_ref,
|
||||
ck::identity{},
|
||||
ck::identity{},
|
||||
[&scale](const AccDataType& x) { return scale * x; }); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
|
||||
Tensor<QGradDataType> dq_host_result({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
Tensor<KGradDataType> dk_host_result({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
Tensor<VGradDataType> dv_host_result({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
Tensor<BiasGradDataType> dbias_host_result(
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_result(
|
||||
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_result(
|
||||
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
|
||||
// clang-format off
|
||||
@@ -764,36 +782,36 @@ bool run(const ArgParser& arg_parser)
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2] + key_offset); });
|
||||
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2] + key_offset); });
|
||||
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
bool dq_cur_pass = ck::utils::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dk_cur_pass = ck::utils::check_err(dk_host_result,
|
||||
dk_host_ref,
|
||||
std::string("Error: KGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dv_cur_pass = ck::utils::check_err(dv_host_result,
|
||||
dv_host_ref,
|
||||
std::string("Error: VGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dk_cur_pass = ck_tile::check_err(dk_host_result,
|
||||
dk_host_ref,
|
||||
std::string("Error: KGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dv_cur_pass = ck_tile::check_err(dv_host_result,
|
||||
dv_host_ref,
|
||||
std::string("Error: VGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
bool dbias_cur_pass = true;
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_cur_pass = ck::utils::check_err(dbias_host_result,
|
||||
dbias_host_ref,
|
||||
std::string("Error: BiasGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
|
||||
dbias_host_ref,
|
||||
std::string("Error: BiasGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
|
||||
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
|
||||
@@ -822,11 +840,11 @@ int main(int argc, char* argv[])
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck::half_t>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck::bhalf_t>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
346
example/ck_tile/01_fmha/fmha_bwd.hpp
Normal file
346
example/ck_tile/01_fmha/fmha_bwd.hpp
Normal file
@@ -0,0 +1,346 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using GemmDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::half_t;
|
||||
using OGradDataType = ck_tile::half_t;
|
||||
using QGradDataType = ck_tile::half_t;
|
||||
using KGradDataType = ck_tile::half_t;
|
||||
using VGradDataType = ck_tile::half_t;
|
||||
using BiasGradDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using GemmDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using OGradDataType = ck_tile::bf16_t;
|
||||
using QGradDataType = ck_tile::bf16_t;
|
||||
using KGradDataType = ck_tile::bf16_t;
|
||||
using VGradDataType = ck_tile::bf16_t;
|
||||
using BiasGradDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args: some will be passed to kargs, some will be used to compute grids/blocks
|
||||
struct fmha_bwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* dq_ptr;
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
float scale;
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_lsed,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdOGradDotOKernel>
|
||||
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
bool kHasBias_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kHasDropout_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bool has_bias;
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
@@ -65,7 +65,6 @@ BOOL_MAP = {
|
||||
"f" : "false"
|
||||
}
|
||||
|
||||
DIRECTIONS = ["fwd"]
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -469,7 +468,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@@ -507,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
|
||||
for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
|
||||
if d == None:
|
||||
continue
|
||||
@@ -536,39 +535,574 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
BWD_DQDKDV_PIPELINE_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
|
||||
}
|
||||
|
||||
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
|
||||
}
|
||||
|
||||
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "fmha_bwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
|
||||
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile_{F_idx}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
{F_dropout},
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
|
||||
fmha_bwd_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_bwd_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_epilogue_{F_idx} =
|
||||
ck_tile::FmhaBwdEpilogue<ck_tile::FmhaBwdEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
|
||||
fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
|
||||
}}
|
||||
"""
|
||||
|
||||
# File name of the generated C++ source that holds the fmha_bwd() dispatch function.
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
# Outermost str.format() template of the generated dispatcher. `{{` / `}}` are
# escaped braces that become literal C++ braces; {F_dispatch} receives the
# concatenated per-dtype branches. The generated function returns -1 when no
# branch matches.
FMHA_BWD_API="""
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""

# One branch per data type; {F_if} is 'if' or 'else if', {F_hdim_case} receives
# the concatenated per-hdim branches for that data type.
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# One branch per supported head dimension; selected when both runtime head
# dimensions are no larger than {F_hdim}.
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""

# Innermost branch: matches one concrete kernel configuration (mode, mask, bias,
# dbias, dropout and the four padding checks), then launches the dot_do_o kernel
# followed by the dq/dk/dv kernel and returns the sum of their results.
# {F_spad0} is the seqlen_q padding flag of the dq/dk/dv trait, {F_spad1} the
# one of the dot_do_o trait.
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_dot_do_o_<dot_do_o_trait_>(s, a);
r += fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_>(s, a);
return r;
}}
"""
|
||||
|
||||
@dataclass
class FmhaBwdDQDKDVApiTrait:
    """Describes one dq/dk/dv kernel configuration for the generated dispatcher.

    Besides the plain configuration fields, the class renders the C++ boolean
    expressions (`scheck`, `skcheck`, `dcheck`, `dvcheck`) that the generated
    fmha_bwd() function evaluates at runtime to decide whether this
    configuration matches the problem sizes.
    """
    pipeline : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim : str
    dtype : str # data type
    mode : str # value from MODE_MAP
    bm0 : int # tile size along q seqlen (block size)
    bn0 : int # tile size along k seqlen
    bhdq : int # q head_dim
    bhdv : int # v head_dim
    mask : str
    bias : str # true/false
    dbias : str
    dropout : str
    spad : str
    skpad : str
    dpad : str
    dvpad : str

    @property
    def name(self) -> str:
        """Dash-separated identifier built from every configuration field except the tile sizes."""
        return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    def scheck(self, spad1 : str) -> str:
        """Return the C++ condition on a.seqlen_q for this trait.

        spad1 is the seqlen_q padding flag ('t'/'f') of the accompanying
        dot_do_o kernel; self.spad is the flag of the dq/dk/dv kernel itself.
        The literal 256 is the dot_do_o block size (see the original
        '# BlockSize' remark).
        """
        if self.mode == 'group':
            return 'true' # always support
        elif self.spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} != 0'
        elif self.spad == 'f' and spad1 == 't':
            # The expression is pasted into generated C++, so use `&&` like the
            # rest of the dispatch template rather than the alternative token `and`.
            return f'a.seqlen_q % {self.bm0} == 0 && a.seqlen_q % 256 != 0' # BlockSize
        else: # spad1 == 'f' (the dispatcher only asks for this when self.spad == 'f')
            return 'a.seqlen_q % 256 == 0' # BlockSize

    @property
    def skcheck(self) -> str:
        """C++ condition on a.seqlen_k: padded kernels match non-multiples of bn0."""
        if self.mode == 'group':
            return 'true' # always support
        elif self.skpad == 't':
            return f'a.seqlen_k % {self.bn0} != 0'
        else:
            return f'a.seqlen_k % {self.bn0} == 0'

    @property
    def dcheck(self) -> str:
        """C++ condition on a.hdim_q: padded kernels match non-multiples of bhdq."""
        if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
        else : return f'a.hdim_q % {self.bhdq} == 0'

    @property
    def dvcheck(self) -> str:
        """C++ condition on a.hdim_v: padded kernels match non-multiples of bhdv."""
        if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
        else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
|
||||
class FmhaBwdApiPool:
    """Registry of dq/dk/dv traits that renders the fmha_bwd() dispatch source."""

    def __init__(self, mask_impl):
        # Two-level mapping: dtype -> hdim -> list of registered trait copies.
        self.dq_dk_dv_pool = dict()
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        """Store a shallow copy of `trait` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        per_hdim = self.dq_dk_dv_pool.setdefault(trait.dtype, dict())
        per_hdim.setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatcher: header + nested dtype/hdim/trait branches."""
        dtype_branches = []
        for dtype_pos, (dtype, per_hdim) in enumerate(self.dq_dk_dv_pool.items()):
            hdim_branches = []
            for hdim_pos, (hdim, traits) in enumerate(per_hdim.items()):
                dispatches = []
                for trait_pos, trait in enumerate(traits):
                    trait_kw = 'else if' if trait_pos else 'if'
                    for spad1 in ("t", "f"):
                        # An unpadded dot_do_o variant is never paired with a
                        # padded dq/dk/dv kernel nor used in group mode.
                        if spad1 == "f" and (trait.spad == "t" or trait.mode == "group"):
                            continue
                        dispatches.append(FMHA_BWD_API_INNER_DISPATCH.format(
                            F_if=trait_kw,
                            F_mode=MODE_MAP[trait.mode],
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_bias=BOOL_MAP[trait.bias],
                            F_dbias=BOOL_MAP[trait.dbias],
                            F_dropout=BOOL_MAP[trait.dropout],
                            F_scheck=trait.scheck(spad1=spad1),
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_hdim=hdim,
                            F_dtype=DTYPE_MAP[dtype],
                            F_spad0=BOOL_MAP[trait.spad],
                            F_spad1=BOOL_MAP[spad1],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad]))

                hdim_kw = 'else if' if hdim_pos else 'if'
                hdim_branches.append(FMHA_BWD_API_PER_HDIM_CASE.format(
                    F_if=hdim_kw, F_hdim=hdim, F_inner_dispatch=''.join(dispatches)))
            dtype_kw = 'else if' if dtype_pos else 'if'
            dtype_branches.append(FMHA_BWD_API_PER_DTYPE.format(
                F_if=dtype_kw, F_dtype=dtype, F_hdim_case=''.join(hdim_branches)))

        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = ''.join(dtype_branches))
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
|
||||
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
|
||||
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
|
||||
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
|
||||
# Is it necessary to distinguish between K0~K4?
|
||||
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Block-tile, warp-layout and occupancy parameters of one dq/dk/dv kernel."""
    # --- block tile extents ---
    F_bm0 : int         # q-seqlen tile (block size)
    F_bn0 : int         # k-seqlen tile
    F_bk0 : int         # gemm0 unroll tile (F_bhdq)
    F_bk1 : int         # gemm1 unroll tile (F_bm0)
    F_bk2 : int         # gemm2 unroll tile (F_bhdv)
    F_bk3 : int         # gemm3 unroll tile (F_bm0)
    F_bk4 : int         # gemm4 unroll tile (F_bn0)
    F_bhdq : int        # head_dim of q
    F_bhdv : int        # head_dim of v
    # --- warps per block, gemm0/gemm2 ---
    F_rm0 : int         # warps along q seqlen
    F_rn0 : int         # warps along k seqlen
    F_rk0 : int         # warps along gemm-k (not used)
    # --- warps per block, gemm1/gemm3 ---
    F_rm1 : int         # warps along k seqlen
    F_rn1 : int         # warps along q seqlen
    F_rk1 : int         # warps along gemm-k (not used)
    # --- warps per block, gemm4 ---
    F_rm2 : int         # warps along k seqlen
    F_rn2 : int         # warps along q seqlen
    F_rk2 : int         # warps along gemm-k (not used)
    # --- warp tile extents ---
    F_wm : int          # along m (warp size)
    F_wn : int          # along n
    F_wk : int          # along k
    F_occupancy : int   # occupancy

    @property
    def name(self) -> str:
        """Short tag 'b<bm0>x<bn0>' used inside generated kernel names."""
        return "b{}x{}".format(self.F_bm0, self.F_bn0)
|
||||
|
||||
@dataclass
class FmhaBwdDQDKDVKernel:
    """One generated dq/dk/dv kernel instance: renders its source, file name and API trait."""
    direction : str
    F_idx : int         # counter used to differentiate symbols, not a tunable
    F_hdim : int        # hdim
    F_dtype : str       # data type
    F_tile : FmhaBwdDQDKDVTileSize
    F_spad : str        # true/false
    F_skpad : str       #
    F_dpad : str        #
    F_dvpad : str       #
    F_bias : str        #
    F_dbias : str       #
    F_dropout : str     #
    F_mask : str        # value from MASK_MAP
    F_mode : str        # value from MODE_MAP
    F_pipeline : str
    mask_impl : str

    @property
    def template(self) -> str:
        """Kernel header followed by the kernel body with every placeholder filled in."""
        tile_keys = ('F_bm0', 'F_bn0', 'F_bk0', 'F_bk1', 'F_bk2', 'F_bk3', 'F_bk4',
                     'F_bhdq', 'F_bhdv',
                     'F_rm0', 'F_rn0', 'F_rk0',
                     'F_rm1', 'F_rn1', 'F_rk1',
                     'F_rm2', 'F_rn2', 'F_rk2',
                     'F_wm', 'F_wn', 'F_wk',
                     'F_occupancy')
        flag_keys = ('F_spad', 'F_skpad', 'F_dpad', 'F_dvpad', 'F_bias', 'F_dbias', 'F_dropout')

        subst = dict(F_idx = self.F_idx,
                     F_hdim = self.F_hdim,
                     F_dtype = DTYPE_MAP[self.F_dtype])
        # Tile parameters are forwarded unchanged under their own names.
        for key in tile_keys:
            subst[key] = getattr(self.F_tile, key)
        # 't'/'f' switches are translated through BOOL_MAP.
        for key in flag_keys:
            subst[key] = BOOL_MAP[getattr(self, key)]
        subst['F_mask'] = get_mask_map(self.mask_impl)[self.F_mask]
        subst['F_mode'] = MODE_MAP[self.F_mode]
        subst['F_pipeline_enum'] = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline]
        subst['F_pipeline'] = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]

        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(**subst)

    @property
    def name(self) -> str:
        """Kernel name encoding direction, hdim, dtype, mode, tile, padding, bias, dropout and mask."""
        # Mask suffix: 's_*' masks contribute '_mask' only for 's_mask';
        # other masks contribute '_m<first letter>' unless the mask is 'no'.
        if self.F_mask[0:2] == 's_':
            mask_tag = '_mask' if self.F_mask == 's_mask' else ''
        else:
            mask_tag = f'_m{self.F_mask[0]}' if self.F_mask != 'no' else ''

        # TODO: we don't encode idx here
        pad_tag = ''.join(BOOL_MAP[flag][0]
                          for flag in (self.F_spad, self.F_skpad, self.F_dpad, self.F_dvpad))
        base = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
        base += f"_p{pad_tag}"
        base += f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
        return base + mask_tag

    @property
    def filename(self) -> str:
        """Name of the generated source file for this kernel."""
        return f"{self.name}.cpp"

    def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
        """Build the dispatcher-side trait that corresponds to this kernel."""
        tile = self.F_tile
        return FmhaBwdDQDKDVApiTrait(
            pipeline=self.F_pipeline,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bhdq=tile.F_bhdq,
            bhdv=tile.F_bhdv,
            mask=self.F_mask,
            bias=self.F_bias,
            dbias=self.F_dbias,
            dropout=self.F_dropout,
            spad=self.F_spad,
            skpad=self.F_skpad,
            dpad=self.F_dpad,
            dvpad=self.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return the supported {hdim: [tile, pipeline]} table, or None if unsupported.

    Only direction 'bwd' with dtype 'fp16' or 'bf16' is supported. Keys are the
    head dimension as a string; each value is a two-element list holding the
    FmhaBwdDQDKDVTileSize and the pipeline tag.
    """
    if direction != 'bwd':
        return None
    if dtype not in ('fp16', 'bf16'):
        return None

    tile_d32  = FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1)
    tile_d64  = FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1)
    tile_d128 = FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1)

    table = dict()
    table['32'] = [tile_d32, "qs_ks_vr_dos"]
    table['64'] = [tile_d64, "ks_kts_vr"]
    table['128'] = [tile_d128, "ks_vr"]
    return table
|
||||
|
||||
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    """Enumerate every dq/dk/dv kernel to generate and register each with an API pool.

    Returns the populated pool together with the kernels, in enumeration order.
    When `kernel_filter` is given, only kernels whose name matches that fnmatch
    pattern are kept.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad
    #       support this in future
    direction = "bwd"
    switch = ["t", "f"]
    pool = FmhaBwdApiPool(mask_impl)
    kernels = list()

    for dtype in DTYPE_MAP.keys():
        table = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if table is None:
            continue
        # Seven on/off axes: bias, dbias, dropout, spad, skpad, dpad, dvpad.
        combos = itertools.product(table.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(),
                                   switch, switch, switch, switch, switch, switch, switch)
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in combos:
            tile = table[hdim_str][0]
            ppl = table[hdim_str][1]
            hdim = int(hdim_str)
            # Group mode is only generated with both sequence paddings enabled.
            if mode == "group" and "f" in (spad, skpad):
                continue
            # A bias gradient without a bias is not generated.
            if dbias == "t" and bias == "f":
                continue
            kernel = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                                         F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                                         F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
                                         F_pipeline=ppl, mask_impl=mask_impl)
            if kernel_filter is not None and not fnmatch.fnmatch(kernel.name, kernel_filter):
                continue
            pool.register_dq_dk_dv_traits(kernel.api_trait())
            kernels.append(kernel)

    return (pool, kernels)
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
fmha_bwd_dot_do_o_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
|
||||
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
|
||||
fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
template<>
|
||||
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
class FmhaBwdOGradDotOKernel:
    """One generated dot_do_o kernel instance: renders its source and file name."""
    direction : str
    F_idx : int         # counter used to differentiate symbols, not a tunable
    F_hdim : int        # hdim
    F_dtype : str       # data type
    F_spad : str        # true/false
    F_dvpad : str       #
    F_mode : str        # value from MODE_MAP
    F_occupancy : int

    @property
    def template(self) -> str:
        """Kernel header followed by the dot_do_o body with every placeholder filled in."""
        subst = {
            'F_idx': self.F_idx,
            'F_hdim': self.F_hdim,
            'F_dtype': DTYPE_MAP[self.F_dtype],
            'F_spad': BOOL_MAP[self.F_spad],
            'F_dvpad': BOOL_MAP[self.F_dvpad],
            'F_mode': MODE_MAP[self.F_mode],
            'F_occupancy': self.F_occupancy,
        }
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(**subst)

    @property
    def name(self) -> str:
        """Kernel name encoding direction, hdim, dtype, mode, padding flags and occupancy."""
        # TODO: we don't encode idx here
        pad_tag = BOOL_MAP[self.F_spad][0] + BOOL_MAP[self.F_dvpad][0]
        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_p{pad_tag}_o{self.F_occupancy}"

    @property
    def filename(self) -> str:
        """Name of the generated source file for this kernel."""
        return f"{self.name}.cpp"
|
||||
|
||||
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate every dot_do_o kernel to generate, in enumeration order."""
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        # Fixed for now; both arguments are currently ignored.
        return 2

    direction = "bwd"
    switch = ["t", "f"]
    kernels = list()

    for dtype in DTYPE_MAP.keys():
        table = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if table is None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(table.keys(), MODE_MAP.keys(), switch, switch):
            hdim = int(hdim_str)
            # Group mode is only generated with seqlen_q padding enabled.
            if spad == "f" and mode == "group":
                continue
            kernels.append(FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                                  F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                                                  F_occupancy=get_occupancy(dtype, hdim)))

    return kernels
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one forward kernel into `autogen_dir`."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
|
||||
|
||||
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the rendered forward dispatch source into `autogen_dir`."""
    destination = autogen_dir / FMHA_FWD_API_FILENAME
    destination.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one dq/dk/dv kernel into `autogen_dir`."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one dot_do_o kernel into `autogen_dir`."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
|
||||
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write the rendered backward dispatch source into `autogen_dir`."""
    destination = autogen_dir / FMHA_BWD_API_FILENAME
    destination.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate all kernel sources and the dispatch source for one direction.

    output_dir    -- base directory; files go into <output_dir>/GEN_DIR. When
                     None, files are written next to this script.
    direction     -- 'fwd' selects the forward kernels; any other value selects
                     the backward kernels (dot_do_o + dq/dk/dv).
    kernel_filter -- optional fnmatch pattern forwarded to the blob enumerators
                     (not applied to the dot_do_o kernels).
    receipt       -- forwarded to the forward enumerator only.
    mask_impl     -- mask implementation selector forwarded to the enumerators.
    """
    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir) / GEN_DIR

    output_dir.mkdir(parents=True, exist_ok=True)
    # The direction-agnostic get_blobs()/write_single_kernel()/write_api() calls
    # that used to sit here are superseded by the per-direction branches below.
    if direction == 'fwd':
        api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_fwd_kernel(kernel, output_dir)
        write_fwd_api(api_pool, output_dir)
    else:
        kernels = get_bwd_dot_do_o_blobs()
        for kernel in kernels:
            write_single_bwd_dot_do_o_kernel(kernel, output_dir)
        api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
        for kernel in kernels:
            write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
        write_bwd_api(api_pool, output_dir)
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Write the paths of every file write_blobs() would generate, one per line.

    Paths are formed as <directory of output_file>/GEN_DIR/<file name>. For
    'fwd' the forward kernels and the forward API file are listed; for any
    other direction the dot_do_o kernels, the dq/dk/dv kernels and the
    backward API file are listed.
    """
    assert output_file is not None
    file_path = Path(output_file)
    gen_dir = file_path.parent / GEN_DIR
    # Truncate instead of appending: the list is regenerated on every configure
    # run, and appending would accumulate duplicate entries in the same file.
    with file_path.open('w') as f:
        if direction == 'fwd':
            _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                f.write(str(gen_dir / kernel.filename) + "\n")
            f.write(str(gen_dir / FMHA_FWD_API_FILENAME) + "\n")
        else:
            kernels = get_bwd_dot_do_o_blobs()
            for kernel in kernels:
                f.write(str(gen_dir / kernel.filename) + "\n")
            _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
            for kernel in kernels:
                f.write(str(gen_dir / kernel.filename) + "\n")
            f.write(str(gen_dir / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen api for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction",
|
||||
default='fwd',
|
||||
choices=['fwd', 'bwd'],
|
||||
required=False,
|
||||
help="choose the direction of kernels(default: fwd)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
@@ -608,6 +1142,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
|
||||
list_blobs(args.list_blobs, args.direction, args.filter, args.receipt, mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, args.direction, args.filter, args.receipt, mask_impl=args.mask)
|
||||
|
||||
21
example/ck_tile/01_fmha/script/benchmark_bwd.sh
Normal file
21
example/ck_tile/01_fmha/script/benchmark_bwd.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/sh
# Benchmark the fmha backward example over precision, permutation and head
# dimension, sweeping batch/sequence pairs whose product is always 16384.
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0

for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do

nhead=$((2048 / hdim)) # follow fav2 setup

# each entry is <batch>:<seqlen>
for shape in 32:512 16:1024 8:2048 4:4096 2:8192 1:16384 ; do
    batch=${shape%%:*}
    seqlen=${shape##*:}
    $EXE -prec=$prec -b=$batch -h=$nhead -d=$hdim -s=$seqlen -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
done

done
done
done
|
||||
33
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
Normal file
33
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/sh
# Smoke test for the fmha backward example: runs six problem shapes for every
# combination of precision, permutation, head dim, mode, bias, dbias and dropout.
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1

COMMON_ARGS='-v=1 -warmup=0 -repeat=1'

# run_case <batch/head flags> <seqlen flags> <mask flag or empty>
# The arguments are expanded unquoted on purpose so that each one splits into
# its individual flags (and an empty mask argument vanishes).
run_case() {
    $EXE -prec=$prec $1 -d=$hdim $2 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm $3 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
}

for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
for mode in 0 1 ; do
for bias in 0 1 ; do
for dbias in 0 1 ; do
for p_drop in 0.0 0.2; do

run_case "-b=1 -h=4 -h_k=2" "-s=259"          ""
run_case "-b=2 -h=2"        "-s=516 -s_k=253" ""
run_case "-b=1 -h=4 -h_k=1" "-s=500 -s_k=251" "-mask=1"
run_case "-b=1 -h=2"        "-s=900 -s_k=258" "-mask=2"
run_case "-b=2 -h=1"        "-s=987 -s_k=219" "-mask=t:128,30"
run_case "-b=2 -h=3 -h_k=1" "-s=244 -s_k=499" "-mask=b:4,35"

done
done
done
done
done
done
done
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
@@ -37,6 +38,7 @@
|
||||
#include "ck_tile/core/tensor/slice_tile.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/store_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/tensor/sweep_tile.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
|
||||
@@ -765,21 +765,21 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
// buffer store ui16
|
||||
__device__ void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
__device__ void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
__device__ void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -1658,7 +1658,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
|
||||
175
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
175
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,175 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Adds two bf16 values by widening both to float, adding, and narrowing the
// sum back to bf16.
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);
    return type_convert<bf16_t>(lhs + rhs);
}
|
||||
|
||||
// Element-wise sum of two 2-wide bf16 vectors; each lane goes through add_bf16_t.
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t sum;
    sum[0] = add_bf16_t(a[0], b[0]); // lane 0
    sum[1] = add_bf16_t(a[1], b[1]); // lane 1
    return sum;
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
// Atomic add of a 2-wide bf16 vector, implemented as a compare-and-swap loop
// on the 32-bit word that holds both lanes: read the current word, compute the
// lane-wise sum with add_bf16x2_t, and try to publish it with atomicCAS; retry
// with the freshly observed word until the swap succeeds.
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
    // Views the destination pointer either as bf16x2_t* or as uint32_t*.
    union U32BF162_ADDR
    {
        uint32_t* u32_a;
        bf16x2_t* bf162_a;
    };

    // Views one 32-bit value either as raw bits or as two bf16 lanes.
    union U32BF162
    {
        uint32_t u32;
        bf16x2_t bf162;
    };

    U32BF162_ADDR dword_addr;
    U32BF162 cur_v;
    U32BF162 new_;
    uint32_t old_v, new_v;
    dword_addr.bf162_a = p_dst;
    // Initial (non-atomic) snapshot of the destination word.
    cur_v.u32          = *dword_addr.u32_a;

    do
    {
        // Value we expect to still be in memory when the swap is attempted.
        old_v      = cur_v.u32;
        new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
        new_v      = new_.u32;
        // atomicCAS yields the word it found at the address; it differs from
        // old_v exactly when the word changed since our snapshot, in which
        // case the loop recomputes the sum from that newer word.
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return atomicAdd(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x2_t>()[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
}
|
||||
@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BufferView_, typename TensorDesc_>
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
struct tensor_view
|
||||
{
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
@@ -24,6 +26,7 @@ struct tensor_view
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
@@ -140,6 +143,23 @@ struct tensor_view
|
||||
x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
|
||||
old_tensor_view.buf_, new_desc};
|
||||
return tensor_view<typename OldTensorView::buffer_view,
|
||||
remove_cvref_t<decltype(new_desc)>,
|
||||
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
|
||||
@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's bottom tensor coordinate
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
55
include/ck_tile/core/tensor/update_tile.hpp
Normal file
55
include/ck_tile/core/tensor/update_tile.hpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,4 +4,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/custom_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
41
include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp
Normal file
41
include/ck_tile/ops/epilogue/custom_2d_epilogue.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AccDataType_, typename KGradDataType_, typename VGradDataType_>
|
||||
struct FmhaBwdEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using KGradDataType = remove_cvref_t<KGradDataType_>;
|
||||
using VGradDataType = remove_cvref_t<VGradDataType_>;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct FmhaBwdEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename KGradDramWindowTmp,
|
||||
typename VGradDramWindowTmp,
|
||||
typename KGradAccTile,
|
||||
typename VGradAccTile>
|
||||
CK_TILE_DEVICE auto operator()(KGradDramWindowTmp& dk_dram_window_tmp,
|
||||
VGradDramWindowTmp& dv_dram_window_tmp,
|
||||
const KGradAccTile& dk_acc_tile,
|
||||
const VGradAccTile& dv_acc_tile)
|
||||
{
|
||||
store_tile(dk_dram_window_tmp, cast_tile<KGradDataType>(dk_acc_tile));
|
||||
store_tile(dv_dram_window_tmp, cast_tile<VGradDataType>(dv_acc_tile));
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -17,6 +17,19 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
1331
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
Normal file
1331
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaBwdTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
template <ck_tile::index_t kBlockSize>
|
||||
struct FmhaBwdOGradDotOTilePartitioner
|
||||
{
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -18,10 +18,10 @@ struct FmhaFwdTilePartitioner
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
|
||||
struct BlockFmhaBwdOGradDotO
|
||||
{
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kVHeaddim = Problem::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename ODramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
float p_undrop) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kBlockSize ==
|
||||
OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
o_dram_block_window_tmp.get_window_lengths(),
|
||||
o_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePreODramTileDistribution<Problem>());
|
||||
|
||||
auto o = load_tile(o_dram_window);
|
||||
|
||||
auto do_dram_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
do_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePreOGradDramTileDistribution<Problem>());
|
||||
|
||||
auto do_ = load_tile(do_dram_window);
|
||||
|
||||
// declare d
|
||||
constexpr auto d_dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
|
||||
|
||||
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
|
||||
|
||||
clear_tile(d); // Initialize D
|
||||
|
||||
constexpr auto o_spans = decltype(o)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
d(i_idx) +=
|
||||
(type_convert<DDataType>(o[i_j_idx]) * type_convert<DDataType>(do_[i_j_idx]));
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
|
||||
|
||||
store_tile(d_dram_block_window_tmp, d);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// These templates are not used here.
|
||||
using BlockFmhaBwdOGradDotODefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ false,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ false,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,821 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = true;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_kts_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType,
|
||||
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: dk_acc and dv_acc have both been cleared at this point, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto kt_dram_block_window = kt_dram_block_window_tmp;
|
||||
|
||||
auto kt_dram_window = make_tile_window(
|
||||
kt_dram_block_window.get_bottom_tensor_view(),
|
||||
kt_dram_block_window.get_window_lengths(),
|
||||
kt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
|
||||
// load
|
||||
|
||||
auto kt_block_tile = load_tile(kt_dram_window);
|
||||
|
||||
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
|
||||
|
||||
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is v located in regs, k & k^t located in lds.
|
||||
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ true,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,794 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. The tensor distribution should be able to override this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc & dv_acc are both cleared, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy alias for the backward dQ/dK/dV pipeline in which V is located
// in registers and K is located in LDS.
// It instantiates BlockFmhaBwdPipelineDefaultPolicy with only KLoadOnce_ and
// VLoadOnce_ set to true; every other *LoadOnce_ flag is false.
|
||||
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,665 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = true;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. The tensor distribution should be able to override this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "qs_ks_vr_dos";
|
||||
|
||||
// Returns the value of Policy::GetSmemSize<Problem>() unchanged; this pipeline
// adds nothing of its own to the size computed by the policy.
// NOTE(review): presumably the total LDS footprint of the pipeline in bytes
// (sibling GetSmemSizeK is used as a char* offset into smem_ptr) -- confirm
// against the policy implementation.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: dk_acc & dv_acc are both cleared at this point, so return them as-is
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
}
|
||||
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
block_sync_lds();
|
||||
store_tile(do_lds_window, do_block_tile); // store the prefetch
|
||||
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
get_slice_tile(dot_lds_window,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
get_slice_tile(do_lds_window,
|
||||
sequence<0, i_k2 * kK2>{},
|
||||
sequence<kM0, (i_k2 + 1) * kK2>{}),
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
block_sync_lds();
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
get_slice_tile(qt_lds_window,
|
||||
sequence<0, i_k3 * kK3>{},
|
||||
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is v located in regs, q & k & do located in lds.
|
||||
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ true,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Pipeline selector tags for the FMHA backward pass.
// This enum is used by the code generator for pattern matching.
enum class BlockFmhaBwdPipelineEnum
{
    KSKTSVR      = 0,
    QSKSVROGradS = 1,
    KSVR         = 2,
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename GemmDataType_,
|
||||
typename LSEDataType_,
|
||||
typename AccDataType_,
|
||||
typename DDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename ODataType_,
|
||||
typename OGradDataType_,
|
||||
typename QGradDataType_,
|
||||
typename KGradDataType_,
|
||||
typename VGradDataType_,
|
||||
typename BiasGradDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using GemmDataType = remove_cvref_t<GemmDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using DDataType = remove_cvref_t<DDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using OGradDataType = remove_cvref_t<OGradDataType_>;
|
||||
using QGradDataType = remove_cvref_t<QGradDataType_>;
|
||||
using KGradDataType = remove_cvref_t<KGradDataType_>;
|
||||
using VGradDataType = remove_cvref_t<VGradDataType_>;
|
||||
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Traits::kHasBias;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
typename OGradDataType_,
|
||||
typename DDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kVHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdOGradDotOPipelineProblem
|
||||
{
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using OGradDataType = remove_cvref_t<OGradDataType_>;
|
||||
using DDataType = remove_cvref_t<DDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
|
||||
"kBlockSize should be divisible by get_warp_size()");
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kVHeaddim = kVHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -703,7 +703,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(AsyncCopyK)
|
||||
{
|
||||
@@ -716,7 +716,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr ck_tile::index_t GetSmemSizeDropout()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
|
||||
{
|
||||
if constexpr(Problem::kHasDropout)
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
@@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
@@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
{b_warp_window_tmp}};
|
||||
|
||||
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// Construct C-Block-HostTensor
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
228
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
Normal file
228
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
Normal file
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBRegCRegV1DefaultPolicy>
|
||||
struct BlockGemmASmemBRegCRegV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockTensorTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockTensorTmp& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
|
||||
|
||||
// constrcut from B-block-tensor from B-Block-tensor-tmp
|
||||
// FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto b_block_tensor =
|
||||
make_static_distributed_tensor<typename BBlockTensorTmp::DataType>(b_block_dstr);
|
||||
|
||||
b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
{b_warp_window_tmp}};
|
||||
|
||||
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockTensorTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockTensorTmp& b_block_tensor_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename CType_,
|
||||
typename BlockWarps_,
|
||||
typename WarpGemm_>
|
||||
struct BlockGemmASmemBRegCRegV1CustomPolicy
|
||||
{
|
||||
using AType = remove_cvref_t<AType_>;
|
||||
using BType = remove_cvref_t<BType_>;
|
||||
using CType = remove_cvref_t<CType_>;
|
||||
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
|
||||
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||
|
||||
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBRegCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmASmemBRegCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -526,9 +526,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
@@ -541,14 +541,14 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user