diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp new file mode 100644 index 0000000000..ec52755f07 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -0,0 +1,833 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor/tensor_view.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/common_header.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "common/arg_parser.hpp" +#include "fmha_bwd.hpp" +#include "mask.hpp" +#include "reference/reference_batched_elementwise.hpp" +#include "reference/reference_batched_gemm.hpp" +#include "reference/reference_batched_masking.hpp" +#include "reference/reference_batched_softmax.hpp" +#include "reference/reference_batched_dropout.hpp" +#include "utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "0", + "num of head, for k/v, 0 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. 
if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", "0", "add bias or not") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, possitive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, possitive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)\n") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 
0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck::make_tuple(rtol, atol); +} + +template +bool run(const ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck::index_t batch = arg_parser.get_int("b"); + ck::index_t nhead = arg_parser.get_int("h"); + ck::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck::index_t seqlen_q = arg_parser.get_int("s"); + ck::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck::index_t hdim_q = arg_parser.get_int("d"); + ck::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + if(hdim_q % 2 != 0 || hdim_v % 2 != 0) + { + std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; + return false; + } + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); + + bool use_bias = arg_parser.get_bool("bias"); + bool 
use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(use_dbias && !use_bias) + { + std::cerr << "dbias only exists when there is a bias" << std::endl; + return false; + } + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + StreamConfig stream_config{ + nullptr, true, /* log_level = */ (kname ? 
1 : 0), stream_warmup, stream_repeat}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaBwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using GemmDataType = typename TypeConfig::GemmDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using AccDataType = typename TypeConfig::AccDataType; + using DDataType = typename TypeConfig::DDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using ODataType = typename TypeConfig::ODataType; + using OGradDataType = typename TypeConfig::OGradDataType; + using QGradDataType = typename TypeConfig::QGradDataType; + using KGradDataType = typename TypeConfig::KGradDataType; + using VGradDataType = typename TypeConfig::VGradDataType; + using BiasGradDataType = typename TypeConfig::BiasGradDataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + using namespace ck::literals; + + flop += nhead * + (3_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + 2_uz * 2_uz * real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + + num_byte += nhead * 
(sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * real_seqlen_k * hdim_v + + sizeof(ODataType) * real_seqlen_q * hdim_v + + sizeof(OGradDataType) * real_seqlen_q * hdim_v + + sizeof(QGradDataType) * real_seqlen_q * hdim_q + + sizeof(KGradDataType) * real_seqlen_k * hdim_q + + sizeof(VGradDataType) * real_seqlen_k * hdim_v + + sizeof(LSEDataType) * real_seqlen_q); + } + } + + auto get_lengths = [&](bool permute, + ck::index_t b /*batch*/, + ck::index_t h /*nhead*/, + ck::index_t s /*seqlen*/, + ck::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + Tensor v_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + Tensor bias_host( + use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + Tensor lse_host(std::array{batch, nhead, max_seqlen_q}); + Tensor d_host(std::array{batch, nhead, max_seqlen_q}); + Tensor randval_host( + p_drop > 0 ? 
get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + Tensor dq_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + Tensor dk_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + Tensor dv_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + Tensor do_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + Tensor dbias_host( + use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + if(init_method == 0) + { + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + } + else if(init_method == 1) + { + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 2) + { + ck::utils::FillTrigValue{}(q_host); + ck::utils::FillTrigValue{}(k_host); + ck::utils::FillTrigValue{}(v_host); + ck::utils::FillTrigValue{}(bias_host); + ck::utils::FillTrigValue{}(do_host); + } + + DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes()); + DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); + DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); + DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); + DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); + DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes()); + DeviceMem d_buf(d_host.GetElementSpaceSizeInBytes()); 
+ DeviceMem randval_buf(randval_host.GetElementSpaceSizeInBytes()); + DeviceMem dq_buf(dq_host.GetElementSpaceSizeInBytes()); + DeviceMem dk_buf(dk_host.GetElementSpaceSizeInBytes()); + DeviceMem dv_buf(dv_host.GetElementSpaceSizeInBytes()); + DeviceMem do_buf(do_host.GetElementSpaceSizeInBytes()); + DeviceMem dbias_buf(dbias_host.GetElementSpaceSizeInBytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + do_buf.ToDevice(do_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask + << std::flush; + + auto fmha_traits = fmha_bwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + mask.type, + use_bias, + use_dbias, + p_drop > 0.0f}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. 
+ // setup stride_* arguments + const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck::index_t stride_randval = (max_seqlen_k); + const ck::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck::index_t stride_dbias = (i_perm ? shape_seqlen_k : nhead * shape_seqlen_k); + // setup nhead_stride_* arguments + const ck::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + const ck::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck::index_t nhead_stride_lsed = max_seqlen_q; + const ck::index_t nhead_stride_dbias = + (i_perm ? 
shape_seqlen_q * shape_seqlen_k : shape_seqlen_k); + // setup batch_stride_* arguments + const ck::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck::index_t batch_stride_dbias = (nhead * shape_seqlen_q * shape_seqlen_k); + + return fmha_bwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + do_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + dq_buf.GetDeviceBuffer(), + dk_buf.GetDeviceBuffer(), + dv_buf.GetDeviceBuffer(), + dbias_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + stride_randval, + stride_do, + stride_dk, + stride_dv, + stride_dbias, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + nhead_stride_randval, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dbias, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + batch_stride_randval, + batch_stride_do, + batch_stride_lsed, + 
batch_stride_dk, + batch_stride_dv, + batch_stride_dbias, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + p_undrop, + s_randval, + {drop_seed, drop_offset}}; + }(); + + float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + bool pass = true; + + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; + + randval_buf.FromDevice(randval_host.data()); + + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); + + Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + Tensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + Tensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + Tensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + Tensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + Tensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + Tensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + + // reference + // S = scale * Q * K^T + reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, ck::identity{}, ck::identity{}, [&](AccDataType x) { + return scale * x; + }); // s_g_m_n = scale * q_g_m_k@k_g_n_k + + if(use_bias) + { + // clang-format off + Tensor 
bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + reference_batched_masking(s_host_ref, + FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + reference_batched_masking( + s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + reference_batched_masking( + s_host_ref, + ck::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + reference_batched_masking( + s_host_ref, + ck::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + reference_batched_softmax( + s_host_ref, p_hp_host_ref, lse_host_ref); + + if(p_drop > 0) + { + p_hp_host_ref.ForEach( + [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck::type_convert(self(idx)); + }); + } + else + { + p_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = 
ck::type_convert(self(idx)); + }); + } + + // O = P * V + reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + + // clang-format off + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); + // clang-format on + + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } + } + + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + + fmha_bwd(fmha_traits, fmha_args, stream_config); + + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); + + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); + + Tensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + Tensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + Tensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + Tensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + Tensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + Tensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + Tensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + Tensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + + // clang-format off + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + + if(p_drop > 0) + { + reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } + + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + auto idx_gmo = idx_gmn; + idx_gmo[2] = o; + do_dot_o += ck::type_convert(do_host_ref(idx_gmo)) * + ck::type_convert(o_host_refs[wb](idx_gmo)); + } + self(idx_gmn) = ck::type_convert(p_hp_host_refs[wb](idx_gmn) * + (dp_hp_host_ref(idx_gmn) - do_dot_o)); + }); + + if(use_dbias) + { + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + dbias_host_ref(idx) = ck::type_convert(self(idx)); + }); + } + + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + ds_lp_host_ref(idx) = ck::type_convert(self(idx)); + }); + + // dV 
= P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck::identity{}, + ck::identity{}, + [&scale](const AccDataType& x) { return scale * x; }); // dq_g_m_k = ds_g_m_n@k_g_k_n + + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck::identity{}, + ck::identity{}, + [&scale](const AccDataType& x) { return scale * x; }); // dk_g_n_k = ds_g_n_m@q_g_k_m + + Tensor dq_host_result({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + Tensor dk_host_result({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + Tensor dv_host_result({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + Tensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + + // clang-format off + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); 
}); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2] + key_offset); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2] + key_offset); }); + } + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool dq_cur_pass = ck::utils::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck::utils::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck::utils::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); + + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = ck::utils::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 
0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8ca4ff9337..158701e00d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -94,6 +94,9 @@ auto create_args(int argc, char* argv[]) "11939", "random seed used for initializing input tensors. 0 for " "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -111,20 +114,11 @@ auto get_elimit(int /*init_method*/) } template <> -auto get_elimit(int init_method) +auto get_elimit(int /*init_method*/) { - if(init_method == 0) - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else - { - double rtol = 3e-3; - double atol = 3e-3; - return ck_tile::make_tuple(rtol, atol); - } + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); } template <> @@ -207,9 +201,23 @@ bool run(const ck_tile::ArgParser& arg_parser) scale_o = range_p * range_v / range_o / dtype_max; } - std::string vlayout = arg_parser.get_str("vlayout"); - bool use_bias = arg_parser.get_bool("bias"); - bool lse = arg_parser.get_bool("lse"); + std::string vlayout = arg_parser.get_str("vlayout"); + bool use_bias = arg_parser.get_bool("bias"); + bool lse = arg_parser.get_bool("lse"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = 
arg_parser.get_uint64("drop_offset"); + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); @@ -232,21 +240,23 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = FmhaFwdTypeConfig; - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -258,6 +268,11 @@ bool run(const ck_tile::ArgParser& arg_parser) 
max_seqlen_q = real_seqlen_q; } + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); @@ -303,12 +318,16 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + if(init_method == 0) { ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); @@ -350,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -373,8 +393,8 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s - << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", bias:" << use_bias << ", p_drop:" << p_drop << ", lse:" << lse + << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << std::flush; auto 
fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -384,6 +404,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, use_bias, lse, + p_drop > 0.0f, squant}; auto p_compute_element_func = [&]() { @@ -415,8 +436,9 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -428,20 +450,23 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); - const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -462,22 +487,28 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_k, stride_v, stride_bias, + stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias, + batch_stride_randval, batch_stride_lse, batch_stride_o, mask.left, mask.right, - static_cast(mask.type)}; + static_cast(mask.type), + p_drop, + s_randval, + {drop_seed, drop_offset}}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); @@ -504,6 +535,11 @@ bool run(const ck_tile::ArgParser& arg_parser) 
o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; bool pass = true; @@ -629,6 +665,17 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, p_host_ref, p_compute_element_func); } + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, @@ -662,9 +709,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b, idx[0], idx[1] + query_offset); - }); + lse_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); bool lse_pass = ck_tile::check_err(lse_host_result, lse_host_ref, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9a82ab6b79..43246f1ace 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -16,61 +16,65 @@ struct FmhaFwdTypeConfig; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - 
using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = 
ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + 
using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; }; struct FmhaMasks @@ -87,6 +91,7 @@ struct fmha_fwd_args const void* k_ptr; const void* v_ptr; const void* bias_ptr; + void* rand_val_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -107,22 +112,28 @@ struct fmha_fwd_args ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; + ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; }; template @@ -137,6 +148,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, @@ -144,6 +156,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqlen_k_ptr, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -152,16 +165,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, 
args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_lse, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -169,12 +188,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -183,22 +204,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, + args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); @@ -221,6 +248,7 @@ template ; static constexpr bool kHasBias = kHasBias_; static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; @@ -263,6 +292,7 @@ struct fmha_fwd_traits mask_enum mask_type; bool has_bias; bool has_lse; + bool has_dropout; bool do_fp8_static_quant; // TODO: padding check is inside this api }; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 56d699e5fe..f41d3d3fff 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -93,7 +93,9 @@ using fmha_trait_{F_idx} = 
ck_tile::TileFmhaTraits<{F_spad}, {F_dpad}, {F_dvpad}, {F_bias}, + false, {F_lse}, + {F_dropout}, {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; @@ -105,6 +107,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, @@ -128,7 +131,7 @@ using fmha_kernel_{F_idx} = fmha_epilogue_{F_idx}>; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -173,9 +176,9 @@ MASK_SIMPLIFIED_CHECK_MAP = { "s_mask" : "t.mask_type != mask_enum::no_mask", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, 
{F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -214,8 +217,9 @@ class FmhaFwdApiTrait: vlayout : str mask : str bias : str # true/false - lse : str # - squant : str # + lse : str + dropout : str + squant : str spad : str skpad : str dpad : str @@ -224,7 +228,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -281,6 +285,7 @@ class FmhaFwdPipeline: F_dvpad : str # F_bias : str # true/false F_lse : str # + F_dropout : str # F_squant : str # F_mask : str # value from MASK_MAP @@ -303,6 +308,7 @@ class FmhaFwdPipeline: else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' return n @@ -332,7 +338,7 @@ class FmhaFwdApiPool: if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, 
F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -346,7 +352,7 @@ class FmhaFwdApiPool: @dataclass class FmhaFwdTileSize: F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along qk seqlen + F_bn0 : int # tile size along k seqlen F_bk0 : int # tile size along qk gemm unroll F_bn1 : int # tile size along v head_dim F_bk1 : int # tile size along kv gemm unroll @@ -402,6 +408,7 @@ class FmhaFwdKernel: F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BOOL_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -435,6 +442,7 @@ class FmhaFwdKernel: mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, @@ -472,26 +480,26 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 
'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse kernels + # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): - 
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) else: assert False return pipelines diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh similarity index 100% rename from example/ck_tile/01_fmha/script/benchmark.sh rename to example/ck_tile/01_fmha/script/benchmark_fwd.sh diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh similarity index 55% rename from example/ck_tile/01_fmha/script/smoke_test.sh rename to example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 4dd5c2ae12..a02b227a0b 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -18,15 +18,16 @@ for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in 0 1 ; do +for p_drop in 0.0 0.2; do -# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME 
$COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done @@ -35,6 +36,7 @@ done done done done +done for perm in 0 1 ; do for bias in 0 1 ; do diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp 
index bb19c9154b..bdf8d79d34 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -56,3 +56,4 @@ #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/unary_element_function.hpp" +#include "ck_tile/core/utility/philox_rand.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 53f42a7421..39d755f0d9 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -764,6 +764,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); +// buffer store ui16 +__device__ void +llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, @@ -1334,7 +1356,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || 
N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 @@ -1473,6 +1498,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d static_cast(coherence)); } } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(uint16_t), + static_cast(coherence)); + } + } else { using r_t = thread_buffer; diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 85d9be1c94..c23c12f295 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); +// ui8 +// using uint8_t +using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); +using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); +using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); +using uint8x16_t = uint8_t __attribute((ext_vector_type(16))); +using uint8x32_t = uint8_t __attribute((ext_vector_type(32))); +using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); + #if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index c12ad883d9..2efc657013 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp new file mode 100644 index 0000000000..d68381e369 --- /dev/null +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck_tile { + +// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh +class philox +{ + public: + __host__ __device__ inline philox(unsigned long long seed_, unsigned long long offset_) + : seed(reinterpret_cast(seed_)) + { + + ull2* tmp = reinterpret_cast(&counter); + tmp->x = offset_; + } + + __host__ __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const + { + + uint4 counter_ = counter; + ull2* tmp = reinterpret_cast(&counter_); + tmp->y = subsequence; + + uint2 key_ = seed; +// 7-round philox +#pragma unroll + for(int i = 0; i < 6; i++) + { + counter_ = philox_single_round(counter_, key_); + key_.x += kPhilox10A; + key_.y += kPhilox10B; + } + uint4 output = philox_single_round(counter_, key_); + return output; + } + + __host__ __device__ void get_random_16x8(uint8_t* out, + const unsigned long long subsequence) const + { + uint4 tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32_t* out_tmp = reinterpret_cast(&out[0]); + + out_tmp[0] = tmp_ph.x; + out_tmp[1] = tmp_ph.y; + out_tmp[2] = tmp_ph.z; + out_tmp[3] = tmp_ph.w; + } + + private: + struct ull2 + { + uint64_t x; + uint64_t y; + }; + uint4 counter; + const uint2 seed; + + __host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) const + { + uint2* res; + unsigned long long tmp; + tmp = static_cast(a) * b; + res = reinterpret_cast(&tmp); + return *res; + } + + __host__ __device__ inline uint4 philox_single_round(const uint4 ctr, const uint2 key) const + { + + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace 
ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0c4a778226..62fce34d1a 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -15,6 +15,7 @@ #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index cd0dc38259..bb60fc8172 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,7 +156,7 @@ struct HostTensorDescriptor } const std::vector& get_lengths() const { return mLens; } - const std::vector& GetStrides() const { return mStrides; } + const std::vector& get_strides() const { return mStrides; } template std::size_t GetOffsetFromMultiIndex(Is... 
is) const @@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old for(std::size_t i = 0; i < a.get_num_of_dimension(); i++) { new_lengths[i] = a.get_lengths()[new2old[i]]; - new_strides[i] = a.GetStrides()[new2old[i]]; + new_strides[i] = a.get_strides()[new2old[i]]; } return HostTensorDescriptor(new_lengths, new_strides); @@ -327,7 +327,7 @@ struct HostTensor decltype(auto) get_lengths() const { return mDesc.get_lengths(); } - decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + decltype(auto) get_strides() const { return mDesc.get_strides(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } @@ -481,6 +481,34 @@ struct HostTensor return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } + HostTensor Transpose(std::vector axes = {}) const + { + if(axes.empty()) + { + axes.resize(this->get_num_of_dimension()); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if(axes.size() != mDesc.get_num_of_dimension()) + { + throw std::runtime_error( + "HostTensor::Transpose(): size of axes must match tensor dimension"); + } + std::vector tlengths, tstrides; + for(const auto& axis : axes) + { + tlengths.push_back(get_lengths()[axis]); + tstrides.push_back(get_strides()[axis]); + } + HostTensor ret(*this); + ret.mDesc = HostTensorDescriptor(tlengths, tstrides); + return ret; + } + + HostTensor Transpose(std::vector axes = {}) + { + return const_cast const*>(this)->Transpose(axes); + } + typename Data::iterator begin() { return mData.begin(); } typename Data::iterator end() { return mData.end(); } diff --git a/include/ck_tile/host/reference/reference_batched_dropout.hpp b/include/ck_tile/host/reference/reference_batched_dropout.hpp new file mode 100644 index 0000000000..242101bf4d --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_dropout.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_dropout(HostTensor& in_out_b_m_n, + const HostTensor& randval_b_m_n, + const uint8_t& p_undrop_in_uint8_t, + const float scale) +{ + const int N = in_out_b_m_n.mDesc.get_lengths()[2]; + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + float tmp = ck_tile::type_convert(in_out_b_m_n(batch, m, n)) * scale; + in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t + ? ck_tile::type_convert(tmp) + : DataType(0); + } + }; + + make_ParallelTensorFunctor( + f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index c567e63ddf..9d08a55bf6 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp new file mode 100644 index 0000000000..1f0fe2bd64 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + constexpr 
index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
+ constexpr auto randval_block_inner_part_dstr_encoding = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + else + { + return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + }(); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = 
MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + + auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + 
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx1 = + tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval.kThreadElementSpaceSize == 16); + + const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + int block_row_start = (start_m0_idx / WG::kM) + i_m0; + int block_col_start = (start_n0_idx / WG::kN) + 
(i_n0 * NWarp) + get_warp_id(); + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + randval(r_idx) = random_uint8_t[i_random_idx++]; + constexpr auto p_idx0 = + tile_distributed_index{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + ck_tile::philox ph; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 7fb1c19b5f..fff0a9f690 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -141,6 +141,36 @@ struct GenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if you need to loop over the Y axis tile by tile (like a q-seqlen loop) + // TODO: y_end still could be negative, so end-start could be negative (needs check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range, assuming we loop along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round down to tile alignment + }(); + + // TODO: end could be negative; we skip the clamp here and let the caller check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -167,7 +197,7 @@ struct GenericAttentionMask // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention!
assume the index passed in this function is within range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if you need to loop over the Y axis tile by tile (like a q-seqlen loop) + // TODO: y_end still could be negative, so end-start could be negative (needs check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range, assuming we loop along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round down to tile alignment + }(); + + // TODO: end could be negative; we skip the clamp here and let the caller check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -289,7 +349,7 @@ struct SimplifiedGenericAttentionMask // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention!
assume the index passed in this function is within range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size, { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0732fd2ce2..c60ea432bd 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once @@ -8,11 +8,11 @@ #include #include -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] -// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] namespace ck_tile { @@ -31,8 +31,10 @@ struct FmhaFwdKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; - using ODataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; @@ -43,6 +45,7 @@ struct FmhaFwdKernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kHasBias = FmhaPipeline::kHasBias; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -81,7 +84,8 @@ struct FmhaFwdKernel "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? 
"_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -108,6 +112,7 @@ struct FmhaFwdKernel ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; + ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; @@ -153,19 +158,44 @@ struct FmhaFwdKernel { void* lse_ptr = nullptr; ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + struct FmhaFwdCommonDropoutKargs { - ck_tile::index_t batch_stride_lse = 0; + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; }; struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -178,7 +208,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> 
+ std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -193,12 +224,14 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -207,22 +240,28 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -232,6 +271,7 @@ struct FmhaFwdKernel seqlen_k, hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -250,6 +290,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v, @@ -279,6 +320,15 @@ struct FmhaFwdKernel kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = 
stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -289,6 +339,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -296,6 +347,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -304,16 +356,22 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -323,6 +381,7 @@ struct FmhaFwdKernel -1, // hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -341,6 +400,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -361,12 +421,21 @@ struct FmhaFwdKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + 
kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -398,12 +467,13 @@ struct FmhaFwdKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -431,7 +501,11 @@ struct FmhaFwdKernel } if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; } batch_offset_o = query_start * kargs.stride_o; @@ -469,6 +543,11 @@ struct FmhaFwdKernel { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -642,6 +721,62 @@ struct FmhaFwdKernel } }(); + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = 
kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -666,6 +801,7 @@ struct FmhaFwdKernel identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func + randval_dram_window, lse_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func @@ -673,7 +809,8 @@ struct FmhaFwdKernel composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func mask, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } else { @@ -681,10 +818,12 @@ struct FmhaFwdKernel k_dram_window, v_dram_window, bias_dram_window, + randval_dram_window, lse_dram_window, mask, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9d27b2df68..624491de8f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -13,6 +13,7 @@ template struct BlockFmhaPipelineProblem { - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; @@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kHasBias = Traits::kHasBias; static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e239bb916..6e1768329f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -13,19 +14,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -48,6 +50,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector 
length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -105,6 +108,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -123,6 +127,7 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -130,7 +135,8 @@ struct BlockFmhaPipelineQRKSVS const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -237,6 +243,9 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -450,6 +459,12 @@ struct BlockFmhaPipelineQRKSVS }); }); + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + block_sync_lds(); if constexpr(std::is_same_v) { @@ -563,16 +578,19 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp> CK_TILE_HOST_DEVICE auto operator()(const 
QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -582,6 +600,7 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -589,7 +608,8 @@ struct BlockFmhaPipelineQRKSVS identity{}, mask, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 0573b50d04..de75313130 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -14,19 +15,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVSAsync { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -53,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -116,6 +119,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -134,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -141,7 +146,8 @@ struct BlockFmhaPipelineQRKSVSAsync const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -288,6 +294,9 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -532,6 +541,17 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -661,16 +681,19 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp> 
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -680,6 +703,7 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -687,7 +711,8 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, mask, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 0e59ee6fe0..c5f41f4323 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -13,19 +13,20 @@ namespace ck_tile { template struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -48,6 +49,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -105,18 +107,21 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, float scale_s, float descale_qk, float descale_sv, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& /*dropout*/) const // not supported { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 677c05769c..995c250ade 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -12,19 +12,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQSKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4fda6f008f..7b2940bd6b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { // TODO: assume Q is in register // TODO: assume K/V has same data type @@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + __host__ __device__ static constexpr ck_tile::index_t GetSmemSize() + { + if constexpr(AsyncCopyK) + { + return GetSmemSizeKV() + GetSmemSizeDropout(); + } + else + { + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + } + } + + template + __host__ __device__ static constexpr ck_tile::index_t GetSmemSizeDropout() + { + if constexpr(Problem::kHasDropout) + { + constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm(); + constexpr auto config = + decltype(gemm_0)::Policy::template 
GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + + return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); + } + else + { + return 0; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index d8a290b09c..64a61e94d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -43,4 +43,53 @@ struct TileFmhaShape ck_tile::tensor_layout::gemm::ColumnMajor>; }; +template +struct TileFmhaBwdShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + using Gemm2BlockWarps = remove_cvref_t; + using Gemm2WarpTile = remove_cvref_t; + using Gemm3BlockWarps = remove_cvref_t; + using Gemm3WarpTile = remove_cvref_t; + using Gemm4BlockWarps = remove_cvref_t; + using Gemm4WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + + static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) && + NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{})); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = + BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll + static constexpr index_t kK1 = + BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll + static constexpr index_t kK2 = + BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll 
+ static constexpr index_t kK3 = + BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll + static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll + static constexpr index_t kQKHeaddim = + BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or + // K/K^T at once + static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline + // that need load V at once +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 137f4ddd81..c2b2ba1f78 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,7 +12,9 @@ template struct TileFmhaTraits @@ -22,9 +24,21 @@ struct TileFmhaTraits static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; static constexpr bool kHasBias = kHasBias_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaBwdOGradDotOTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c7ebcf9606..c97073aaf5 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -3,17 +3,15 @@ #pragma once -#include 
"ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp deleted file mode 100644 index 1053c751ad..0000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -// Problem Description for BlockGemmARegBGmemCReg -template -struct BlockGemmARegBGmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index 7799bbe918..f097790ae6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem, + BlockGemmProblem, BlockGemmARegBSmemCRegV1DefaultPolicy>; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp index 4156398bd3..0a17b05353 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index aac9c4f552..80dda9f17d 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp index 779113d96a..f998c67c95 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 8073989264..9b10d435b6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 405d7f1258..4a82702c1f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp index 8bcd04b7b0..20dcf2c270 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp index c17385b8e5..e90500c28c 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp deleted file mode 100644 index ed772891a4..0000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -// Problem Description for BlockGemmASmemBSmemCRegV1 -template -struct BlockGemmASmemBSmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 40da16d820..ac45221709 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 319711088f..2436457ec1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index fbb957727d..f798d6e815 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp similarity index 88% rename from include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp rename to include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 7a0390a8a2..d8f66c81ca 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -7,13 +7,13 @@ namespace ck_tile { -// Problem Description for BlockGemmARegBSmemCReg +// Problem Description for BlockGemm template -struct BlockGemmARegBSmemCRegProblem +struct BlockGemmProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index dfc63f04c6..5b4419b79f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -38,7 +41,7 @@ using 
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; -using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; @@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; -using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 71c59bbd17..f2e586f794 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB } }; +template +struct WarpGemmAtrributeMfmaIterateK_SwizzleA +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + 
sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + + return c_vec; + } +}; + } // namespace ck_tile