mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
FA fwd dropout
This commit is contained in:
833
example/ck_tile/01_fmha/fmha_bwd.cpp
Normal file
833
example/ck_tile/01_fmha/fmha_bwd.cpp
Normal file
@@ -0,0 +1,833 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor/tensor_view.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
#include "common/arg_parser.hpp"
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "reference/reference_batched_elementwise.hpp"
|
||||
#include "reference/reference_batched_gemm.hpp"
|
||||
#include "reference/reference_batched_masking.hpp"
|
||||
#include "reference/reference_batched_softmax.hpp"
|
||||
#include "reference/reference_batched_dropout.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"0",
|
||||
"num of head, for k/v, 0 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
|
||||
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "0", "head dim for v, 0 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "add bias or not")
|
||||
.insert("dbias", "0", "output bias gradient or not")
|
||||
.insert("prec", "fp16", "data type. fp16 or bf16")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, possitive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, possitive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)\n")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(int /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck::index_t batch = arg_parser.get_int("b");
|
||||
ck::index_t nhead = arg_parser.get_int("h");
|
||||
ck::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
if(nhead_k == 0)
|
||||
nhead_k = nhead;
|
||||
|
||||
if(nhead % nhead_k != 0)
|
||||
{
|
||||
std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
ck::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k == 0)
|
||||
seqlen_k = seqlen_q;
|
||||
ck::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v == 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
{
|
||||
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
|
||||
|
||||
float scale = arg_parser.get_float("scale");
|
||||
if(scale == .0f)
|
||||
scale = 1.0 / ck::math::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool use_dbias = arg_parser.get_bool("dbias");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
if(use_dbias && !use_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when there is a bias" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
float rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
bool s_randval = false;
|
||||
if(p_drop > 0.0f && do_validation)
|
||||
{
|
||||
s_randval = true;
|
||||
}
|
||||
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
int init_method = arg_parser.get_int("init");
|
||||
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
|
||||
StreamConfig stream_config{
|
||||
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
|
||||
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataType>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using GemmDataType = typename TypeConfig::GemmDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using DDataType = typename TypeConfig::DDataType;
|
||||
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using OGradDataType = typename TypeConfig::OGradDataType;
|
||||
using QGradDataType = typename TypeConfig::QGradDataType;
|
||||
using KGradDataType = typename TypeConfig::KGradDataType;
|
||||
using VGradDataType = typename TypeConfig::VGradDataType;
|
||||
using BiasGradDataType = typename TypeConfig::BiasGradDataType;
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
{
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
if(max_seqlen_q < real_seqlen_q)
|
||||
{
|
||||
max_seqlen_q = real_seqlen_q;
|
||||
}
|
||||
|
||||
if(max_seqlen_k < real_seqlen_k)
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
flop += nhead *
|
||||
(3_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
|
||||
2_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
sizeof(VDataType) * real_seqlen_k * hdim_v +
|
||||
sizeof(ODataType) * real_seqlen_q * hdim_v +
|
||||
sizeof(OGradDataType) * real_seqlen_q * hdim_v +
|
||||
sizeof(QGradDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KGradDataType) * real_seqlen_k * hdim_q +
|
||||
sizeof(VGradDataType) * real_seqlen_k * hdim_v +
|
||||
sizeof(LSEDataType) * real_seqlen_q);
|
||||
}
|
||||
}
|
||||
|
||||
auto get_lengths = [&](bool permute,
|
||||
ck::index_t b /*batch*/,
|
||||
ck::index_t h /*nhead*/,
|
||||
ck::index_t s /*seqlen*/,
|
||||
ck::index_t d /*hdim*/) {
|
||||
if(permute)
|
||||
return std::array<ck::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck::index_t, 4>{b, s, h, d};
|
||||
};
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
|
||||
Tensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
Tensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
Tensor<VDataType> v_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
Tensor<BiasDataType> bias_host(
|
||||
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
Tensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
Tensor<LSEDataType> lse_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
Tensor<DDataType> d_host(std::array<ck::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
Tensor<RandValOutputDataType> randval_host(
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1});
|
||||
Tensor<QGradDataType> dq_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
Tensor<KGradDataType> dk_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
|
||||
Tensor<VGradDataType> dv_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
|
||||
Tensor<OGradDataType> do_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
Tensor<BiasGradDataType> dbias_host(
|
||||
use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck::utils::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
ck::utils::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck::utils::FillTrigValue<QDataType>{}(q_host);
|
||||
ck::utils::FillTrigValue<KDataType>{}(k_host);
|
||||
ck::utils::FillTrigValue<VDataType>{}(v_host);
|
||||
ck::utils::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck::utils::FillTrigValue<OGradDataType>{}(do_host);
|
||||
}
|
||||
|
||||
DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem d_buf(d_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem randval_buf(randval_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dq_buf(dq_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dk_buf(dk_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dv_buf(dv_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem do_buf(do_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem dbias_buf(dbias_host.GetElementSpaceSizeInBytes());
|
||||
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
if (permute) return std::string("bhsd");
|
||||
else return std::string("bshd");
|
||||
};
|
||||
auto io_layout = [&](bool iperm_, bool operm_) {
|
||||
if (iperm_ == operm_) return layout_str(iperm_);
|
||||
else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
|
||||
};
|
||||
// clang-format on
|
||||
const std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
|
||||
<< std::flush;
|
||||
|
||||
auto fmha_traits = fmha_bwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
use_bias,
|
||||
use_dbias,
|
||||
p_drop > 0.0f};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
// setup stride_* arguments
|
||||
const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
|
||||
const ck::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_randval = (max_seqlen_k);
|
||||
const ck::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck::index_t stride_dbias = (i_perm ? shape_seqlen_k : nhead * shape_seqlen_k);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck::index_t nhead_stride_lsed = max_seqlen_q;
|
||||
const ck::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * shape_seqlen_k : shape_seqlen_k);
|
||||
// setup batch_stride_* arguments
|
||||
const ck::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
|
||||
const ck::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck::index_t batch_stride_lsed = (nhead * max_seqlen_q);
|
||||
const ck::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck::index_t batch_stride_dbias = (nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
do_buf.GetDeviceBuffer(),
|
||||
d_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
dq_buf.GetDeviceBuffer(),
|
||||
dk_buf.GetDeviceBuffer(),
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_o,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
if(!do_validation)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::vector<Tensor<QDataType>> q_host_refs;
|
||||
std::vector<Tensor<KDataType>> k_host_refs;
|
||||
std::vector<Tensor<VDataType>> v_host_refs;
|
||||
std::vector<Tensor<ODataType>> o_host_refs;
|
||||
std::vector<Tensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<Tensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<Tensor<GemmDataType>> p_lp_host_refs;
|
||||
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
Tensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
|
||||
Tensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
|
||||
Tensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
|
||||
Tensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
|
||||
Tensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
|
||||
Tensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
|
||||
Tensor<AccDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
|
||||
Tensor<AccDataType> p_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
Tensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
Tensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
|
||||
ck::index_t nr = nhead / nhead_k;
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
// clang-format on
|
||||
|
||||
// reference
|
||||
// S = scale * Q * K^T
|
||||
reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, ck::identity{}, ck::identity{}, [&](AccDataType x) {
|
||||
return scale * x;
|
||||
}); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
if(use_bias)
|
||||
{
|
||||
// clang-format off
|
||||
Tensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
reference_batched_masking<AccDataType>(s_host_ref,
|
||||
FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
reference_batched_masking<AccDataType>(
|
||||
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, lse_host_ref);
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
|
||||
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
|
||||
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
|
||||
// clang-format on
|
||||
|
||||
q_host_refs.push_back(q_host_ref);
|
||||
k_host_refs.push_back(k_host_ref);
|
||||
v_host_refs.push_back(v_host_ref);
|
||||
o_host_refs.push_back(o_host_ref);
|
||||
p_hp_host_refs.push_back(p_hp_host_ref);
|
||||
p_lp_host_refs.push_back(p_lp_host_ref);
|
||||
if(p_drop > 0)
|
||||
{
|
||||
randval_host_refs.push_back(randval_host_ref);
|
||||
}
|
||||
}
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
|
||||
for(ck::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
Tensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
|
||||
Tensor<AccDataType> ds_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
|
||||
Tensor<GemmDataType> ds_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
|
||||
Tensor<AccDataType> dp_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
|
||||
Tensor<BiasGradDataType> dbias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
Tensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
Tensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
Tensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
|
||||
// clang-format off
|
||||
if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// dP = dO@V x Z w/ dropout
|
||||
// dP = dO@V w/o dropout
|
||||
auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
|
||||
reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
|
||||
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck::type_convert<AccDataType>(p_hp_host_refs[wb](idx_gmn) *
|
||||
(dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
|
||||
auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
|
||||
ds_lp_host_ref,
|
||||
k_t_host_ref,
|
||||
dq_host_ref,
|
||||
ck::identity{},
|
||||
ck::identity{},
|
||||
[&scale](const AccDataType& x) { return scale * x; }); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
|
||||
reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
|
||||
ds_t_lp_host_ref,
|
||||
q_t_host_ref,
|
||||
dk_host_ref,
|
||||
ck::identity{},
|
||||
ck::identity{},
|
||||
[&scale](const AccDataType& x) { return scale * x; }); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
|
||||
Tensor<QGradDataType> dq_host_result({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
Tensor<KGradDataType> dk_host_result({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
Tensor<VGradDataType> dv_host_result({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
Tensor<BiasGradDataType> dbias_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
|
||||
if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); });
|
||||
else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); });
|
||||
|
||||
if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
|
||||
else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2] + key_offset); });
|
||||
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2] + key_offset); });
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
bool dq_cur_pass = ck::utils::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dk_cur_pass = ck::utils::check_err(dk_host_result,
|
||||
dk_host_ref,
|
||||
std::string("Error: KGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dv_cur_pass = ck::utils::check_err(dv_host_result,
|
||||
dv_host_ref,
|
||||
std::string("Error: VGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
bool dbias_cur_pass = true;
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_cur_pass = ck::utils::check_err(dbias_host_result,
|
||||
dbias_host_ref,
|
||||
std::string("Error: BiasGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
|
||||
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
|
||||
{
|
||||
std::cerr << "mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck::bhalf_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -94,6 +94,9 @@ auto create_args(int argc, char* argv[])
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -111,20 +114,11 @@ auto get_elimit(int /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
auto get_elimit<ck_tile::bf16_t>(int /*init_method*/)
|
||||
{
|
||||
if(init_method == 0)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
double rtol = 3e-3;
|
||||
double atol = 3e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -207,9 +201,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
scale_o = range_p * range_v / range_o / dtype_max;
|
||||
}
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool s_randval = false;
|
||||
if(p_drop > 0.0f && do_validation)
|
||||
{
|
||||
s_randval = true;
|
||||
}
|
||||
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
@@ -232,21 +240,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataType>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
using SaccDataType = typename TypeConfig::SaccDataType;
|
||||
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
|
||||
using PDataType = typename TypeConfig::PDataType;
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
using SaccDataType = typename TypeConfig::SaccDataType;
|
||||
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
|
||||
using PDataType = typename TypeConfig::PDataType;
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -258,6 +268,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_seqlen_q = real_seqlen_q;
|
||||
}
|
||||
|
||||
if(max_seqlen_k < real_seqlen_k)
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
|
||||
|
||||
@@ -303,12 +318,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
@@ -350,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -373,8 +393,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s
|
||||
<< ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
<< ", bias:" << use_bias << ", p_drop:" << p_drop << ", lse:" << lse
|
||||
<< ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
|
||||
auto fmha_traits = fmha_fwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
@@ -384,6 +404,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.type,
|
||||
use_bias,
|
||||
lse,
|
||||
p_drop > 0.0f,
|
||||
squant};
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
@@ -415,8 +436,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -428,20 +450,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
@@ -462,22 +487,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type)};
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
@@ -504,6 +535,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
float rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
@@ -629,6 +665,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
@@ -662,9 +709,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
|
||||
@@ -16,61 +16,65 @@ struct FmhaFwdTypeConfig;
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
@@ -87,6 +91,7 @@ struct fmha_fwd_args
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
@@ -107,22 +112,28 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
@@ -137,6 +148,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
@@ -144,6 +156,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
@@ -152,16 +165,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -169,12 +188,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
@@ -183,22 +204,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -221,6 +248,7 @@ template <ck_tile::index_t HDim_,
|
||||
typename FmhaMask_,
|
||||
bool kHasBias_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
@@ -242,6 +270,7 @@ struct fmha_fwd_traits_
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
@@ -263,6 +292,7 @@ struct fmha_fwd_traits
|
||||
mask_enum mask_type;
|
||||
bool has_bias;
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -93,7 +93,9 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
@@ -105,6 +107,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
@@ -128,7 +131,7 @@ using fmha_kernel_{F_idx} =
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -173,9 +176,9 @@ MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_mask" : "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -214,8 +217,9 @@ class FmhaFwdApiTrait:
|
||||
vlayout : str
|
||||
mask : str
|
||||
bias : str # true/false
|
||||
lse : str #
|
||||
squant : str #
|
||||
lse : str
|
||||
dropout : str
|
||||
squant : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
@@ -224,7 +228,7 @@ class FmhaFwdApiTrait:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -281,6 +285,7 @@ class FmhaFwdPipeline:
|
||||
F_dvpad : str #
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@@ -303,6 +308,7 @@ class FmhaFwdPipeline:
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
return n
|
||||
|
||||
@@ -332,7 +338,7 @@ class FmhaFwdApiPool:
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
|
||||
@@ -346,7 +352,7 @@ class FmhaFwdApiPool:
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along qk seqlen
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
@@ -402,6 +408,7 @@ class FmhaFwdKernel:
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
@@ -435,6 +442,7 @@ class FmhaFwdKernel:
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
@@ -472,26 +480,26 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
@@ -18,15 +18,16 @@ for vlayout in "r" "c" ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for lse in 0 1 ; do
|
||||
for bias in 0 1 ; do
|
||||
for p_drop in 0.0 0.2; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -35,6 +36,7 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
for perm in 0 1 ; do
|
||||
for bias in 0 1 ; do
|
||||
@@ -56,3 +56,4 @@
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -764,6 +764,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
// buffer store ui16
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -1334,7 +1356,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, uint16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(std::is_same<T, float>::value) // fp32
|
||||
@@ -1473,6 +1498,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint16_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16(bit_cast<uint16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast<uint16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast<uint16x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(
|
||||
src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(
|
||||
src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(uint16_t),
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using r_t = thread_buffer<int8_t, sizeof(T) * N>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// ui8
|
||||
// using uint8_t
|
||||
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
|
||||
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
|
||||
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
|
||||
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
|
||||
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
|
||||
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
87
include/ck_tile/core/utility/philox_rand.hpp
Normal file
87
include/ck_tile/core/utility/philox_rand.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
__host__ __device__ inline philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
__host__ __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
|
||||
uint2 key_ = seed;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 6; i++)
|
||||
{
|
||||
counter_ = philox_single_round(counter_, key_);
|
||||
key_.x += kPhilox10A;
|
||||
key_.y += kPhilox10B;
|
||||
}
|
||||
uint4 output = philox_single_round(counter_, key_);
|
||||
return output;
|
||||
}
|
||||
|
||||
__host__ __device__ void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
|
||||
out_tmp[0] = tmp_ph.x;
|
||||
out_tmp[1] = tmp_ph.y;
|
||||
out_tmp[2] = tmp_ph.z;
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
__host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
tmp = static_cast<unsigned long long>(a) * b;
|
||||
res = reinterpret_cast<uint2*>(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
__host__ __device__ inline uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_masking.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
const std::vector<std::size_t>& get_lengths() const { return mLens; }
|
||||
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
|
||||
const std::vector<std::size_t>& get_strides() const { return mStrides; }
|
||||
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
|
||||
for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
|
||||
{
|
||||
new_lengths[i] = a.get_lengths()[new2old[i]];
|
||||
new_strides[i] = a.GetStrides()[new2old[i]];
|
||||
new_strides[i] = a.get_strides()[new2old[i]];
|
||||
}
|
||||
|
||||
return HostTensorDescriptor(new_lengths, new_strides);
|
||||
@@ -327,7 +327,7 @@ struct HostTensor
|
||||
|
||||
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
|
||||
|
||||
decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
|
||||
decltype(auto) get_strides() const { return mDesc.get_strides(); }
|
||||
|
||||
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
|
||||
|
||||
@@ -481,6 +481,34 @@ struct HostTensor
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
HostTensor<T> Transpose(std::vector<size_t> axes = {}) const
|
||||
{
|
||||
if(axes.empty())
|
||||
{
|
||||
axes.resize(this->get_num_of_dimension());
|
||||
std::iota(axes.rbegin(), axes.rend(), 0);
|
||||
}
|
||||
if(axes.size() != mDesc.get_num_of_dimension())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"HostTensor::Transpose(): size of axes must match tensor dimension");
|
||||
}
|
||||
std::vector<size_t> tlengths, tstrides;
|
||||
for(const auto& axis : axes)
|
||||
{
|
||||
tlengths.push_back(get_lengths()[axis]);
|
||||
tstrides.push_back(get_strides()[axis]);
|
||||
}
|
||||
HostTensor<T> ret(*this);
|
||||
ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
|
||||
return ret;
|
||||
}
|
||||
|
||||
HostTensor<T> Transpose(std::vector<size_t> axes = {})
|
||||
{
|
||||
return const_cast<HostTensor<T> const*>(this)->Transpose(axes);
|
||||
}
|
||||
|
||||
typename Data::iterator begin() { return mData.begin(); }
|
||||
|
||||
typename Data::iterator end() { return mData.end(); }
|
||||
|
||||
33
include/ck_tile/host/reference/reference_batched_dropout.hpp
Normal file
33
include/ck_tile/host/reference/reference_batched_dropout.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType, typename RandValOutputDataType>
|
||||
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
|
||||
const HostTensor<RandValOutputDataType>& randval_b_m_n,
|
||||
const uint8_t& p_undrop_in_uint8_t,
|
||||
const float scale)
|
||||
{
|
||||
const int N = in_out_b_m_n.mDesc.get_lengths()[2];
|
||||
auto f = [&](auto batch, auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
|
||||
in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
|
||||
? ck_tile::type_convert<DataType>(tmp)
|
||||
: DataType(0);
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
|
||||
329
include/ck_tile/ops/fmha/block/block_dropout.hpp
Normal file
329
include/ck_tile/ops/fmha/block/block_dropout.hpp
Normal file
@@ -0,0 +1,329 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockDropout
|
||||
{
|
||||
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
|
||||
index_t i_head,
|
||||
index_t nheads,
|
||||
unsigned long long seed,
|
||||
unsigned long long offset,
|
||||
float rp_undrop_,
|
||||
uint8_t p_undrop_in_uint8_t_,
|
||||
bool is_store_randval_)
|
||||
: ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
|
||||
rp_undrop(rp_undrop_),
|
||||
p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
|
||||
is_store_randval(is_store_randval_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = WG::kN;
|
||||
constexpr index_t kN1 = 8;
|
||||
constexpr index_t kN0 = kNPerStep / kN1;
|
||||
|
||||
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
|
||||
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
|
||||
number<kN1>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
|
||||
randval_lds_block_desc_0,
|
||||
ck_tile::make_tuple(
|
||||
make_pass_through_transform(number<kMPerStep>{}),
|
||||
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
|
||||
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return randval_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
|
||||
constexpr auto randval_block_inner_part_dstr_encoding = []() {
|
||||
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::CDataType, float>)
|
||||
{
|
||||
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
randval_block_inner_part_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename PComputeDataType,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// randval tile in LDS
|
||||
auto randval_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
|
||||
|
||||
auto randval_lds_window = make_tile_window(
|
||||
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
|
||||
|
||||
// register distribute
|
||||
auto randval_dist_generated =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
|
||||
auto randval_lds_read_window =
|
||||
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
|
||||
randval_lds_window.get_window_lengths(),
|
||||
randval_lds_window.get_window_origin(),
|
||||
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
|
||||
|
||||
const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
|
||||
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
constexpr auto randval_dist_generated_spans =
|
||||
decltype(randval_dist_generated)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
|
||||
});
|
||||
});
|
||||
// save to LDS
|
||||
store_tile(randval_lds_window, randval_dist_generated);
|
||||
block_sync_lds();
|
||||
// read from LDS to register
|
||||
auto randval = load_tile(randval_lds_read_window);
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if(is_store_randval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval.kThreadElementSpaceSize == 16);
|
||||
|
||||
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
|
||||
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval(r_idx) = random_uint8_t[i_random_idx++];
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
|
||||
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx]
|
||||
: -p_compute[p_idx];
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if(is_store_randval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {kMPerStep, 0});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::philox ph;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -141,6 +141,36 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
@@ -167,7 +197,7 @@ struct GenericAttentionMask
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
@@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
@@ -289,7 +349,7 @@ struct SimplifiedGenericAttentionMask
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
@@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,11 +8,11 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
|
||||
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
|
||||
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
|
||||
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -31,8 +31,10 @@ struct FmhaFwdKernel
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using RandValOutputDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
@@ -43,6 +45,7 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
@@ -81,7 +84,8 @@ struct FmhaFwdKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) +
|
||||
(kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -108,6 +112,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
@@ -153,19 +158,44 @@ struct FmhaFwdKernel
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
ck_tile::index_t nhead_stride_lse = 0;
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs
|
||||
struct FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
}
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
bool is_store_randval = false;
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -178,7 +208,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -193,12 +224,14 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
@@ -207,22 +240,28 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -232,6 +271,7 @@ struct FmhaFwdKernel
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
@@ -250,6 +290,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -279,6 +320,15 @@ struct FmhaFwdKernel
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
kargs.batch_stride_randval = batch_stride_randval;
|
||||
kargs.is_store_randval = s_randval;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -289,6 +339,7 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
@@ -296,6 +347,7 @@ struct FmhaFwdKernel
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
@@ -304,16 +356,22 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -323,6 +381,7 @@ struct FmhaFwdKernel
|
||||
-1, //
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
@@ -341,6 +400,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -361,12 +421,21 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
kargs.is_store_randval = s_randval;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -398,12 +467,13 @@ struct FmhaFwdKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -431,7 +501,11 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
@@ -469,6 +543,11 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
@@ -642,6 +721,62 @@ struct FmhaFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// dropout
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 0;
|
||||
uint64_t drop_offset = 0;
|
||||
bool is_store_randval = false;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
rp_undrop = kargs.rp_undrop;
|
||||
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
|
||||
drop_seed = kargs.drop_seed;
|
||||
drop_offset = kargs.drop_offset;
|
||||
is_store_randval = kargs.is_store_randval;
|
||||
}
|
||||
BlockDropout dropout(i_batch,
|
||||
i_nhead,
|
||||
kargs.num_head_q,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
rp_undrop,
|
||||
p_undrop_in_uint8_t,
|
||||
is_store_randval);
|
||||
|
||||
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto randval_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
RandValOutputDataType* rand_val_ptr =
|
||||
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
|
||||
batch_offset_randval;
|
||||
|
||||
const auto randval_dram = [&]() {
|
||||
const auto randval_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
rand_val_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_randval, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
randval_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(randval_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
@@ -666,6 +801,7 @@ struct FmhaFwdKernel
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
@@ -673,7 +809,8 @@ struct FmhaFwdKernel
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -681,10 +818,12 @@ struct FmhaFwdKernel
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ template <typename QDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
@@ -23,19 +24,20 @@ template <typename QDataType_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Traits::kHasBias;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -13,19 +14,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -48,6 +50,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -105,6 +108,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
@@ -123,6 +127,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -130,7 +135,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -237,6 +243,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -450,6 +459,12 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -563,16 +578,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -582,6 +600,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
@@ -589,7 +608,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -14,19 +15,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -53,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -116,6 +119,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
@@ -134,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -141,7 +146,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -288,6 +294,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -532,6 +541,17 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
@@ -661,16 +681,19 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -680,6 +703,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
@@ -687,7 +711,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -13,19 +13,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -48,6 +49,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -105,18 +107,21 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported
|
||||
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
float descale_qk,
|
||||
float descale_sv,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
BlockDropout& /*dropout*/) const // not supported
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,19 +12,20 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
// TODO: assume Q is in register
|
||||
// TODO: assume K/V has same data type
|
||||
@@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
single_smem_size * max(NumPrefetchK, NumPrefetchV);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(AsyncCopyK)
|
||||
{
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr ck_tile::index_t GetSmemSizeDropout()
|
||||
{
|
||||
if constexpr(Problem::kHasDropout)
|
||||
{
|
||||
constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto config =
|
||||
decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = WG::kN;
|
||||
|
||||
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
|
||||
@@ -43,4 +43,53 @@ struct TileFmhaShape
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
typename Gemm0BlockWarps_,
|
||||
typename Gemm0WarpTile_,
|
||||
typename Gemm1BlockWarps_,
|
||||
typename Gemm1WarpTile_,
|
||||
typename Gemm2BlockWarps_,
|
||||
typename Gemm2WarpTile_,
|
||||
typename Gemm3BlockWarps_,
|
||||
typename Gemm3WarpTile_,
|
||||
typename Gemm4BlockWarps_,
|
||||
typename Gemm4WarpTile_>
|
||||
struct TileFmhaBwdShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
|
||||
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
using Gemm2BlockWarps = remove_cvref_t<Gemm2BlockWarps_>;
|
||||
using Gemm2WarpTile = remove_cvref_t<Gemm2WarpTile_>;
|
||||
using Gemm3BlockWarps = remove_cvref_t<Gemm3BlockWarps_>;
|
||||
using Gemm3WarpTile = remove_cvref_t<Gemm3WarpTile_>;
|
||||
using Gemm4BlockWarps = remove_cvref_t<Gemm4BlockWarps_>;
|
||||
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 =
|
||||
BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
|
||||
static constexpr index_t kK1 =
|
||||
BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
|
||||
static constexpr index_t kK2 =
|
||||
BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
|
||||
static constexpr index_t kK3 =
|
||||
BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
|
||||
static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or
|
||||
// K/K^T at once
|
||||
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
|
||||
// that need load V at once
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,7 +12,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kHasBias_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaTraits
|
||||
@@ -22,9 +24,21 @@ struct TileFmhaTraits
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
struct TileFmhaBwdOGradDotOTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,17 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Problem Description for BlockGemmARegBGmemCReg
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmARegBGmemCRegProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1
|
||||
|
||||
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
|
||||
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
|
||||
BlockGemmARegBSmemCRegProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
|
||||
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
|
||||
BlockGemmARegBSmemCRegV1DefaultPolicy>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Problem Description for BlockGemmASmemBSmemCRegV1
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmASmemBSmemCRegProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -7,13 +7,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Problem Description for BlockGemmARegBSmemCReg
|
||||
// Problem Description for BlockGemm
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmARegBSmemCRegProblem
|
||||
struct BlockGemmProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
|
||||
|
||||
@@ -38,7 +41,7 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
|
||||
2>>;
|
||||
@@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
|
||||
|
||||
@@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
|
||||
2>>;
|
||||
|
||||
@@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
Impl::kCMLane,
|
||||
SFactor,
|
||||
Impl::kCM1PerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user