mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout
Add comments with dropout implementation details
Fix performance regression of fwd+dropout
* Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
* "scalarize" seed and offset, they may come either from kernel args or from device memory
(presumably loaded with vector loads).
These changes help the compiler to procude more optimal code and reduce register spilling.
Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding
Use code based on BlockDropout in BlockDropoutBwd
Refactor BlockDropout (fwd)
Implement BlockDropout (fwd) for WMMA
Originally BlockDropout only supported 32x32 tiles (IsWG32 = true),
this version supports 16x16 tiles.
If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly
to BlockDropoutBwd.
Implement BlockDropoutBwd for WMMA
Remove MakeRandValLds* functions unused in BlockDropoutBwd
Remove unused Run overload from BlockDropoutBwd
* Fix regression with philox seed and offset when they exceed 32-bit int
__builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset
are 64-bit so they get truncated.
* Add F32 MFMA warp gemms
* Support f32 in fwd FMHA
* Implement transpose_vectors for 4-byte types (float)
* Fix unexpected implicit f32->uint32 cast in buffer_store<4>
__builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly casted to uint).
mbuf_t types in other buffer_store<> are changed for consistency.
* Support F32 in bwd FMHA
hdim = 256 is disabled for now because it uses too much memory on gfx90a
* Support Headdim = 48 (divisible by 16) in fwd
* Add fp32-specific receipts (800 and 801)
* Tune fwd tiles
* Tune bwd tiles
* Use small tiles only for small seqlen_q
* Fix after rebasing
* Fix selection of a fallback tile based on bm0
The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.
* Remove constraints and adjust filtering for fp32
Custom constraints are no longer needed because now the smallest tile
is selected automtically based on seqlen_q.
Filters related to qr_async_trload disabled valid fp32 tiles.
* Add fp32 tests
* Make splitkv and appendkv compile for fp32 only
There are no instances yet, but API still must compile when only fp32 is
requested.
* Remove unimportant f32 instances
* Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS
* Replace magic numbers with a constant, improve comments for dropout
* Update changelog
* Fix condition that dq_acc must be set to zero when mask is used
The change was introduced in #2799
* Replace warp_uniform with recently added amd_wave_read_first_lane
* Add hdim = 96 and 192 to fwd
1790 lines
79 KiB
C++
1790 lines
79 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/host.hpp"
|
|
#include "ck_tile/ref/naive_attention.hpp"
|
|
#include "fmha_fwd.hpp"
|
|
#include "utils.hpp"
|
|
#include "ck_tile/utility/json_dump.hpp"
|
|
|
|
#include <array>
|
|
#include <cstring>
|
|
#include <functional>
|
|
#include <cmath>
|
|
#include <numeric>
|
|
#include <ostream>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
|
|
#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
|
|
#endif
|
|
|
|
enum class fwd_result
|
|
{
|
|
success,
|
|
failure,
|
|
invalid_args,
|
|
no_instance,
|
|
};
|
|
|
|
// different threshold for different dtype
|
|
template <typename DataTypeConfig>
|
|
auto get_elimit(std::string /*init_method*/)
|
|
{
|
|
double rtol = 1e-3;
|
|
double atol = 1e-3;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<FmhaFwdFp32>(std::string /*init_method*/)
|
|
{
|
|
double rtol = 1e-5;
|
|
double atol = 1e-5;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
|
{
|
|
double rtol = 1e-2;
|
|
double atol = 1e-2;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
|
|
{
|
|
using TypeConfig = FmhaFwdTypeConfig<FmhaFwdFp8>;
|
|
using ODataType = typename TypeConfig::ODataType;
|
|
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
|
double rtol = 0;
|
|
double atol = 16 * (o_dtype_max > 240 ? 2 : 1);
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
|
|
{
|
|
double rtol = 1e-2;
|
|
double atol = 1.8e-1;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
|
|
{
|
|
double rtol = 1e-2;
|
|
double atol = 1.8e-1;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
|
{
|
|
// If we have enough to almost fill the SMs, then just use 1 split
|
|
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
|
{
|
|
return 1;
|
|
}
|
|
max_splits = std::min({max_splits, num_SMs});
|
|
float max_efficiency = 0.f;
|
|
std::vector<float> efficiency;
|
|
efficiency.reserve(max_splits);
|
|
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
|
{
|
|
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
|
float eff = n_waves / ceil(n_waves);
|
|
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
|
if(eff > max_efficiency)
|
|
{
|
|
max_efficiency = eff;
|
|
}
|
|
efficiency.push_back(eff);
|
|
}
|
|
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
|
{
|
|
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
|
{
|
|
// printf("num_splits chosen = %d\n", num_splits);
|
|
return num_splits;
|
|
}
|
|
}
|
|
return 1;
|
|
}
|
|
|
|
int override_num_splits_if_necessary(
|
|
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
|
{
|
|
(void)hdim_v;
|
|
int device;
|
|
auto status = hipGetDevice(&device);
|
|
if(status != hipSuccess)
|
|
{
|
|
return num_splits;
|
|
}
|
|
|
|
hipDeviceProp_t props{};
|
|
status = hipGetDeviceProperties(&props, device);
|
|
if(status != hipSuccess)
|
|
{
|
|
return num_splits;
|
|
}
|
|
|
|
// tile size should match the generate.py
|
|
const int kM0 = 64;
|
|
|
|
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
|
|
|
if(num_splits < 1 && p_drop == 0.0f)
|
|
{
|
|
return num_splits_heuristic(
|
|
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
|
}
|
|
|
|
return num_splits;
|
|
}
|
|
|
|
template <typename DataTypeConfig>
|
|
fwd_result fmha_fwd_run(mode_enum mode,
|
|
ck_tile::index_t batch,
|
|
ck_tile::index_t nhead,
|
|
ck_tile::index_t nhead_k,
|
|
std::vector<ck_tile::index_t> seqlen_qs,
|
|
std::vector<ck_tile::index_t> seqlen_ks,
|
|
ck_tile::index_t hdim_q,
|
|
ck_tile::index_t hdim_v,
|
|
ck_tile::index_t seqlen_knew,
|
|
std::vector<ck_tile::index_t> seqlen_qpads,
|
|
std::vector<ck_tile::index_t> seqlen_kpads,
|
|
std::vector<ck_tile::index_t> q_eff_lens_per_batch,
|
|
std::vector<ck_tile::index_t> kv_eff_lens_per_batch,
|
|
ck_tile::index_t rotary_dim,
|
|
bool i_perm,
|
|
bool o_perm,
|
|
float scale_s,
|
|
float logits_soft_cap,
|
|
bool is_v_rowmajor,
|
|
bool lse,
|
|
ck_tile::index_t page_block_size,
|
|
bool use_cache_batch_idx,
|
|
std::string bias_str,
|
|
float p_drop,
|
|
uint64_t drop_seed,
|
|
uint64_t drop_offset,
|
|
bool drop_prefs,
|
|
std::string mask_str,
|
|
bool squant,
|
|
bool is_rotary_interleaved,
|
|
ck_tile::index_t num_splits,
|
|
std::string init_method,
|
|
uint32_t seed,
|
|
int do_validation,
|
|
const ck_tile::stream_config& stream_config,
|
|
std::optional<std::string> json = std::nullopt)
|
|
{
|
|
const std::string data_type = []() {
|
|
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
|
|
return "fp32";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
|
|
return "fp16";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf16>)
|
|
return "bf16";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
|
return "fp8";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf8>)
|
|
return "bf8";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
|
return "fp8bf16";
|
|
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
|
|
return "fp8fp32";
|
|
else
|
|
static_assert(false);
|
|
}();
|
|
|
|
if(nhead_k < 0)
|
|
nhead_k = nhead;
|
|
if(nhead % nhead_k != 0)
|
|
{
|
|
std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
|
|
std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}());
|
|
auto next_seed = [&random_engine]() { return static_cast<unsigned int>(random_engine()); };
|
|
|
|
if(hdim_v < 0)
|
|
hdim_v = hdim_q;
|
|
|
|
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
|
if(seqlen_knew != 0)
|
|
{
|
|
std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
|
|
<< std::endl;
|
|
seqlen_knew = 0;
|
|
}
|
|
#endif
|
|
if(seqlen_knew < 0)
|
|
{
|
|
seqlen_knew = randint<ck_tile::index_t>(1, seqlen_qs[0], random_engine);
|
|
}
|
|
|
|
if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
|
|
std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
|
|
{
|
|
if(0 < rotary_dim)
|
|
{
|
|
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
}
|
|
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
|
else if(0 < rotary_dim)
|
|
{
|
|
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
|
|
<< std::endl;
|
|
rotary_dim = 0;
|
|
}
|
|
#endif
|
|
// to use fmha_fwd_appendkv(), make sure it's in batch mode
|
|
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
|
if(need_append_kvcache && mode == mode_enum::group)
|
|
{
|
|
std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
|
|
mode = mode_enum::batch;
|
|
}
|
|
if(!(rotary_dim <= hdim_q))
|
|
{
|
|
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
else if(!(rotary_dim % 16 == 0))
|
|
{
|
|
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
|
|
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
|
|
CK_TILE_FMHA_FWD_PAGEDKV_API))
|
|
if(0 < page_block_size)
|
|
{
|
|
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
|
<< std::endl;
|
|
page_block_size = 0;
|
|
}
|
|
#endif
|
|
if(!(page_block_size % 128 == 0))
|
|
{
|
|
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
|
|
<< std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
|
|
#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API)
|
|
if(use_cache_batch_idx)
|
|
{
|
|
std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
|
|
<< std::endl;
|
|
use_cache_batch_idx = false;
|
|
}
|
|
#else
|
|
if(use_cache_batch_idx)
|
|
{
|
|
if(0 < page_block_size)
|
|
{
|
|
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
|
"'cache_batch_idx' option"
|
|
<< std::endl;
|
|
use_cache_batch_idx = false;
|
|
}
|
|
else if(mode == mode_enum::group)
|
|
{
|
|
std::cerr << "group mode will not use cache_batch_idx. ignoring the "
|
|
"'cache_batch_idx' option"
|
|
<< std::endl;
|
|
use_cache_batch_idx = false;
|
|
}
|
|
}
|
|
#endif
|
|
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
|
|
|
|
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
|
|
const bool has_group_padding =
|
|
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
|
|
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
|
|
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
|
|
!kv_eff_lens_per_batch.empty()));
|
|
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
|
|
const bool using_pagedkv = (0 < page_block_size);
|
|
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
|
|
if((using_appendkv || using_pagedkv || using_splitkv) &&
|
|
(has_group_padding || has_batch_efflens))
|
|
{
|
|
std::cerr << "Padding (physical or effective lengths) is not supported with "
|
|
"appendkv/splitkv/pagedkv pipelines"
|
|
<< std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
|
|
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
|
|
generate_missing_seqlens(mode,
|
|
batch,
|
|
seqlen_qs,
|
|
seqlen_ks,
|
|
seqlen_kpads,
|
|
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
|
need_append_kvcache,
|
|
random_engine);
|
|
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
|
{
|
|
if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb])
|
|
{
|
|
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
}
|
|
// compute kvcache seqlen_k (before appending knew/vnew)
|
|
auto cache_seqlen_ks = seqlen_ks;
|
|
std::transform(cache_seqlen_ks.begin(),
|
|
cache_seqlen_ks.end(),
|
|
cache_seqlen_ks.begin(),
|
|
[&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
|
|
|
|
#if 0
|
|
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
|
|
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
|
|
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
|
|
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
|
|
#endif
|
|
|
|
if(scale_s == .0f)
|
|
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
|
|
|
bias_info bias = bias_info::decode(bias_str);
|
|
|
|
mask_info mask =
|
|
mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
|
|
|
|
if(p_drop < 0.0f || p_drop > 1.0f)
|
|
{
|
|
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
|
|
bool s_randval = false;
|
|
if(p_drop > 0.0f && do_validation)
|
|
{
|
|
s_randval = true;
|
|
}
|
|
|
|
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
|
if(num_splits != 1)
|
|
{
|
|
std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
|
|
num_splits = 1;
|
|
}
|
|
#endif
|
|
|
|
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
|
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
|
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
|
|
|
// Optional padded Q seqstarts (group-mode only)
|
|
std::vector<int32_t> seqstart_q_with_padding_host;
|
|
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
|
|
{
|
|
if(seqlen_qpads.size() < static_cast<size_t>(batch))
|
|
{
|
|
seqlen_qpads.resize(batch, seqlen_qpads.back());
|
|
}
|
|
if(seqlen_qpads.size() == static_cast<size_t>(batch))
|
|
{
|
|
seqstart_q_with_padding_host = to_seqstarts(
|
|
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
|
|
}
|
|
}
|
|
|
|
// Optional batch-mode cumulative seqlen overrides
|
|
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
|
|
if(mode == mode_enum::batch)
|
|
{
|
|
auto calculate_cumulative = [&](std::vector<ck_tile::index_t>& per_batch_vec,
|
|
std::vector<ck_tile::index_t>& cum_vec) {
|
|
if(!per_batch_vec.empty() && per_batch_vec[0] != -1)
|
|
{
|
|
if(per_batch_vec.size() < static_cast<size_t>(batch))
|
|
{
|
|
per_batch_vec.resize(batch, per_batch_vec.back());
|
|
}
|
|
cum_vec.resize(batch + 1);
|
|
cum_vec[0] = 0;
|
|
for(int i = 0; i < batch; ++i)
|
|
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
|
}
|
|
};
|
|
|
|
calculate_cumulative(q_eff_lens_per_batch, cuq_cum);
|
|
calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
|
|
}
|
|
|
|
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
|
|
|
using QDataType = typename TypeConfig::QDataType;
|
|
using KDataType = typename TypeConfig::KDataType;
|
|
using VDataType = typename TypeConfig::VDataType;
|
|
using BiasDataType = typename TypeConfig::BiasDataType;
|
|
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
|
using LSEDataType = typename TypeConfig::LSEDataType;
|
|
using SaccDataType = typename TypeConfig::SaccDataType;
|
|
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
|
|
using PDataType = typename TypeConfig::PDataType;
|
|
using OaccDataType = typename TypeConfig::OaccDataType;
|
|
using ODataType = typename TypeConfig::ODataType;
|
|
|
|
// accumulation numbers for performance evaluation
|
|
std::size_t flop = 0, num_byte = 0;
|
|
auto max_seqlen_q =
|
|
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
|
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
|
{
|
|
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
|
{
|
|
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
|
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
|
|
|
if(max_seqlen_q < real_seqlen_q)
|
|
{
|
|
max_seqlen_q = real_seqlen_q;
|
|
}
|
|
|
|
if(max_seqlen_k < real_seqlen_k)
|
|
{
|
|
max_seqlen_k = real_seqlen_k;
|
|
}
|
|
|
|
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
|
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
|
|
|
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
|
sizeof(ODataType) * real_seqlen_q * hdim_v);
|
|
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
|
|
sizeof(VDataType) * hdim_v * real_seqlen_k);
|
|
}
|
|
}
|
|
|
|
const ck_tile::index_t max_num_page_blocks =
|
|
(0 < page_block_size
|
|
? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
|
|
: 0);
|
|
|
|
// legalize num_splits according to other options
|
|
if(num_splits < 1)
|
|
{
|
|
num_splits = override_num_splits_if_necessary(
|
|
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
|
|
}
|
|
if(128 < num_splits)
|
|
{
|
|
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
|
|
return fwd_result::invalid_args;
|
|
}
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
if(0 < p_drop && (1 < num_splits || use_kvcache))
|
|
{
|
|
std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option"
|
|
<< std::endl;
|
|
p_drop = 0.0f;
|
|
}
|
|
#endif
|
|
|
|
static const auto get_lengths = [](bool permute,
|
|
ck_tile::index_t b /*batch*/,
|
|
ck_tile::index_t h /*nhead*/,
|
|
ck_tile::index_t s /*seqlen*/,
|
|
ck_tile::index_t d /*hdim*/) {
|
|
if(permute)
|
|
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
|
else
|
|
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
|
};
|
|
|
|
// host memory for storing all the tensor elements
|
|
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
|
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
|
|
const ck_tile::index_t shape_seqlen_q_lse =
|
|
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
|
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
|
|
const ck_tile::index_t shape_seqlen_q =
|
|
(mode == mode_enum::batch
|
|
? seqlen_qs[0]
|
|
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
|
|
: seqstart_q_with_padding_host.back()));
|
|
const ck_tile::index_t shape_seqlen_k =
|
|
(mode == mode_enum::batch ? seqlen_ks[0]
|
|
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
|
: seqstart_k_with_padding_host.back()));
|
|
|
|
ck_tile::HostTensor<QDataType> q_host(
|
|
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
|
ck_tile::HostTensor<KDataType> k_host(
|
|
0 < page_block_size
|
|
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
|
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
|
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
|
ck_tile::HostTensor<KDataType> knew_host(
|
|
0 < seqlen_knew
|
|
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
|
|
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
|
ck_tile::HostTensor<VDataType> v_host(
|
|
0 < page_block_size
|
|
? (is_v_rowmajor
|
|
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
|
|
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
|
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
|
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
|
ck_tile::HostTensor<VDataType> vnew_host(
|
|
0 < seqlen_knew
|
|
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
|
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
|
|
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
|
ck_tile::HostTensor<BiasDataType> bias_host(
|
|
bias.type == bias_enum::elementwise_bias
|
|
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
|
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
|
|
|
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
|
|
bias.type == bias_enum::alibi
|
|
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
|
|
: std::array<ck_tile::index_t, 2>{batch, nhead})
|
|
: std::array<ck_tile::index_t, 2>{1, 1});
|
|
|
|
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
|
|
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, next_seed());
|
|
|
|
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
|
1 < num_splits || use_kvcache
|
|
? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
|
|
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
|
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
|
1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
|
|
nhead,
|
|
num_splits,
|
|
shape_seqlen_q,
|
|
hdim_v}
|
|
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
|
|
|
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
|
// group mode of lse data layout is [nhead, total_seqlen_q]
|
|
ck_tile::HostTensor<LSEDataType> lse_host(
|
|
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
|
|
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
|
|
|
ck_tile::HostTensor<ODataType> o_host(
|
|
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
|
|
|
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
|
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
|
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
|
|
|
ck_tile::HostTensor<int32_t> block_table_host(
|
|
0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
|
|
: std::array<ck_tile::index_t, 2>{1, 1});
|
|
|
|
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
|
|
? std::array<ck_tile::index_t, 1>{batch}
|
|
: std::array<ck_tile::index_t, 1>{1});
|
|
float max_o = 5.0;
|
|
if(init_method == "ui" || init_method == "0")
|
|
{
|
|
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(k_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(knew_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(v_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(vnew_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
|
bias_host);
|
|
}
|
|
else if(init_method == "ni")
|
|
{
|
|
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
|
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(k_host);
|
|
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(knew_host);
|
|
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(v_host);
|
|
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(vnew_host);
|
|
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
|
bias_host);
|
|
}
|
|
else if(init_method == "uf" || init_method == "1")
|
|
{
|
|
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, next_seed()}(q_host);
|
|
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(k_host);
|
|
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(knew_host);
|
|
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(v_host);
|
|
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(vnew_host);
|
|
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
|
|
}
|
|
else if(init_method == "nf")
|
|
{
|
|
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, next_seed()}(q_host);
|
|
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(k_host);
|
|
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(knew_host);
|
|
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(v_host);
|
|
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(vnew_host);
|
|
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, next_seed()}(bias_host);
|
|
}
|
|
else if(init_method == "tf" || init_method == "2")
|
|
{
|
|
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
|
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
|
ck_tile::FillTrigValue<KDataType>{}(knew_host);
|
|
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
|
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
|
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
|
}
|
|
if(bias.type == bias_enum::alibi)
|
|
{
|
|
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
|
|
assert(slopes.size() == static_cast<std::size_t>(nhead));
|
|
if(bias.rank_info == 0)
|
|
{
|
|
// alibi in 1*h
|
|
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
|
|
}
|
|
else
|
|
{
|
|
// alibi in b*h
|
|
for(auto i_b = 0; i_b < batch; i_b++)
|
|
{
|
|
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
|
|
}
|
|
}
|
|
}
|
|
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
|
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
|
|
|
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
|
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
|
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
|
|
? 0
|
|
: seqstart_q_with_padding_host.size() *
|
|
sizeof(int32_t));
|
|
ck_tile::DeviceMem seqstart_k_padded_buf(
|
|
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
|
|
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
|
|
: cuq_cum.size() * sizeof(ck_tile::index_t));
|
|
ck_tile::DeviceMem cu_seqlen_kv_buf(
|
|
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
|
|
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
|
0 <= seqlen_kpads[0]
|
|
? seqlen_ks.size() * sizeof(int32_t)
|
|
: 0);
|
|
ck_tile::DeviceMem cache_seqlen_k_buf(
|
|
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
|
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
|
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
|
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
|
|
|
float scale_p = 1.f;
|
|
float scale_o = 1.f;
|
|
if(squant)
|
|
{
|
|
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
|
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
|
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
|
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
|
// Q tensor
|
|
{
|
|
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::min());
|
|
q_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
if(val > max_value)
|
|
max_value = val;
|
|
});
|
|
|
|
float scale = q_dtype_max / max_value;
|
|
|
|
q_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
self(idx) = ck_tile::type_convert<QDataType>(val * scale);
|
|
});
|
|
scale_s = scale_s / scale;
|
|
}
|
|
|
|
// K tensor
|
|
{
|
|
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::min());
|
|
k_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
if(val > max_value)
|
|
max_value = val;
|
|
});
|
|
float scale = k_dtype_max / max_value;
|
|
k_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
self(idx) = ck_tile::type_convert<KDataType>(val * scale);
|
|
});
|
|
scale_s = scale_s / scale;
|
|
}
|
|
|
|
// V tensor
|
|
{
|
|
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::min());
|
|
v_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
if(val > max_value)
|
|
max_value = val;
|
|
});
|
|
|
|
float scale = k_dtype_max / max_value;
|
|
v_host.ForEach([&](auto& self, auto idx) {
|
|
float val = ck_tile::type_convert<float>(self(idx));
|
|
self(idx) = ck_tile::type_convert<VDataType>(val * scale);
|
|
});
|
|
|
|
scale_o = (1.0 / p_dtype_max) / scale;
|
|
}
|
|
|
|
scale_p = p_dtype_max;
|
|
|
|
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
|
{
|
|
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
|
scale_o = scale_o * o_dtype_max / max_o;
|
|
}
|
|
}
|
|
|
|
q_buf.ToDevice(q_host.data());
|
|
k_buf.ToDevice(k_host.data());
|
|
v_buf.ToDevice(v_host.data());
|
|
knew_buf.ToDevice(knew_host.data());
|
|
vnew_buf.ToDevice(vnew_host.data());
|
|
bias_buf.ToDevice(bias_host.data());
|
|
seqstart_q.ToDevice(seqstart_q_host.data());
|
|
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
|
seqstart_k.ToDevice(seqstart_k_host.data());
|
|
seqstart_q_padded_buf.ToDevice(
|
|
seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
|
|
seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
|
|
: seqstart_k_with_padding_host.data());
|
|
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
|
|
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
|
|
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
|
? seqlen_ks.data()
|
|
: nullptr);
|
|
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
|
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
|
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
|
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
|
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
|
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
|
block_table_buf.ToDevice(block_table_host.data());
|
|
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
|
|
|
// clang-format off
|
|
auto layout_str = [&](bool permute){
|
|
if(permute) return std::string("bhsd");
|
|
else return std::string("bshd");
|
|
};
|
|
auto io_layout = [&](bool iperm_, bool operm_) {
|
|
if(iperm_ == operm_) return layout_str(iperm_);
|
|
else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
|
|
};
|
|
// clang-format on
|
|
|
|
std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
|
|
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
|
|
<< "/" << seqlen_ks[0]
|
|
<< (seqlen_kpads[0] < 0 ? ""
|
|
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
|
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
|
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
|
<< ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c");
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
if(0 < rotary_dim)
|
|
{
|
|
std::cout << ", rotary_dim:" << rotary_dim << "("
|
|
<< (is_rotary_interleaved ? "inter" : "half") << ")";
|
|
}
|
|
#endif
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
if(1 < num_splits)
|
|
{
|
|
std::cout << ", num_splits:" << num_splits;
|
|
}
|
|
if(0 < page_block_size)
|
|
{
|
|
std::cout << ", page_block_size:" << page_block_size;
|
|
}
|
|
if(use_cache_batch_idx)
|
|
{
|
|
std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
|
|
}
|
|
#endif
|
|
// Padding / effective length diagnostic logging
|
|
auto print_vec = [&](const char* label, const std::vector<int>& v) {
|
|
if(v.empty())
|
|
return;
|
|
std::cout << ", " << label << ":[";
|
|
for(std::size_t i = 0; i < v.size(); ++i)
|
|
{
|
|
if(i)
|
|
std::cout << ",";
|
|
std::cout << v[i];
|
|
}
|
|
std::cout << "]";
|
|
};
|
|
|
|
if(has_group_padding)
|
|
{
|
|
bool has_qpad = !seqstart_q_with_padding_host.empty();
|
|
bool has_kpad = (seqlen_kpads[0] >= 0);
|
|
if(has_qpad)
|
|
{
|
|
print_vec("q_logical", seqlen_qs);
|
|
print_vec("q_padded", seqlen_qpads);
|
|
}
|
|
if(has_kpad)
|
|
{
|
|
print_vec("k_logical", seqlen_ks);
|
|
print_vec("k_padded", seqlen_kpads);
|
|
}
|
|
}
|
|
else if(has_batch_efflens)
|
|
{
|
|
// derive effective lengths from cumulative arrays if present
|
|
if(!cuq_cum.empty())
|
|
{
|
|
std::vector<int> eff_q(batch);
|
|
for(int b_i = 0; b_i < batch; ++b_i)
|
|
eff_q[b_i] = static_cast<int>(cuq_cum[b_i + 1] - cuq_cum[b_i]);
|
|
print_vec("q_eff", eff_q);
|
|
}
|
|
if(!cukv_cum.empty())
|
|
{
|
|
std::vector<int> eff_kv(batch);
|
|
for(int b_i = 0; b_i < batch; ++b_i)
|
|
eff_kv[b_i] = static_cast<int>(cukv_cum[b_i + 1] - cukv_cum[b_i]);
|
|
print_vec("kv_eff", eff_kv);
|
|
}
|
|
}
|
|
|
|
std::cout << std::flush;
|
|
|
|
const auto init_traits = [&](auto& traits) {
|
|
traits.hdim_q = hdim_q;
|
|
traits.hdim_v = hdim_v;
|
|
traits.data_type = data_type;
|
|
traits.is_v_rowmajor = is_v_rowmajor;
|
|
|
|
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
|
|
{
|
|
traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
|
|
: rope_enum::half_rotated)
|
|
: rope_enum::none);
|
|
}
|
|
else // fmha_fwd_traits or fmha_splitkv_traits
|
|
{
|
|
traits.is_group_mode = (mode == mode_enum::group);
|
|
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
|
traits.mask_type = mask.type;
|
|
traits.bias_type = bias.type;
|
|
traits.has_lse = lse;
|
|
traits.do_fp8_static_quant = squant;
|
|
|
|
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
|
{
|
|
traits.has_dropout = (p_drop > 0.0f);
|
|
}
|
|
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
|
|
std::decay_t<decltype(traits)>>)
|
|
{
|
|
traits.use_pagedkv = (0 < page_block_size);
|
|
}
|
|
}
|
|
};
|
|
|
|
const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
|
|
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
|
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
|
/// 'nhead_stride_bias' are 0.
|
|
// setup stride_* arguments
|
|
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
|
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
|
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
|
|
const ck_tile::index_t stride_v = [&]() {
|
|
if(is_v_rowmajor)
|
|
return i_perm ? hdim_v : nhead_k * hdim_v;
|
|
else
|
|
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
|
|
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
|
}();
|
|
const ck_tile::index_t stride_vnew = [&]() {
|
|
if(is_v_rowmajor)
|
|
return i_perm ? hdim_v : nhead_k * hdim_v;
|
|
else
|
|
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
|
}();
|
|
const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k);
|
|
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
|
const ck_tile::index_t stride_o_acc = (hdim_v);
|
|
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
|
// setup nhead_stride_* arguments
|
|
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
|
const ck_tile::index_t nhead_stride_k =
|
|
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
|
|
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
|
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
|
|
const ck_tile::index_t nhead_stride_v = [&]() {
|
|
if(is_v_rowmajor)
|
|
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
|
|
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
|
else
|
|
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
|
|
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
|
}();
|
|
const ck_tile::index_t nhead_stride_vnew = [&]() {
|
|
if(is_v_rowmajor)
|
|
return i_perm ? seqlen_knew * hdim_v : hdim_v;
|
|
else
|
|
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
|
}();
|
|
const ck_tile::index_t nhead_stride_bias =
|
|
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
|
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
|
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
|
|
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
|
|
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
|
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
|
// setup batch_stride_* arguments
|
|
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
|
const ck_tile::index_t batch_stride_k =
|
|
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
|
: (nhead_k * shape_seqlen_k * hdim_q));
|
|
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
|
|
const ck_tile::index_t batch_stride_v =
|
|
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
|
: (nhead_k * hdim_v * shape_seqlen_k));
|
|
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
|
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
|
|
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
|
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
|
|
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
|
|
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
|
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
|
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
|
// setup split_stride_* arguments (only used in split-kv kernel)
|
|
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
|
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
|
|
|
args.q_ptr = q_buf.GetDeviceBuffer();
|
|
args.k_ptr = k_buf.GetDeviceBuffer();
|
|
args.v_ptr = v_buf.GetDeviceBuffer();
|
|
|
|
args.batch = batch;
|
|
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
|
args.hdim_q = hdim_q;
|
|
args.hdim_v = hdim_v;
|
|
args.nhead_q = nhead;
|
|
args.nhead_k = nhead_k;
|
|
|
|
args.stride_q = stride_q;
|
|
args.stride_k = stride_k;
|
|
args.stride_v = stride_v;
|
|
args.nhead_stride_q = nhead_stride_q;
|
|
args.nhead_stride_k = nhead_stride_k;
|
|
args.nhead_stride_v = nhead_stride_v;
|
|
args.batch_stride_q = batch_stride_q;
|
|
args.batch_stride_k = batch_stride_k;
|
|
args.batch_stride_v = batch_stride_v;
|
|
|
|
if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.knew_ptr = knew_buf.GetDeviceBuffer();
|
|
args.vnew_ptr = vnew_buf.GetDeviceBuffer();
|
|
args.seqlen_knew = seqlen_knew;
|
|
|
|
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
|
|
|
|
args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
|
|
args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
|
|
args.rotary_dim = rotary_dim;
|
|
args.has_mask = (mask.type != mask_enum::no_mask);
|
|
|
|
args.block_table_ptr =
|
|
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
|
args.batch_stride_block_table = batch_stride_block_table;
|
|
args.page_block_size = page_block_size;
|
|
|
|
args.cache_batch_idx =
|
|
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
|
|
|
args.stride_knew = stride_knew;
|
|
args.stride_vnew = stride_vnew;
|
|
args.nhead_stride_knew = nhead_stride_knew;
|
|
args.nhead_stride_vnew = nhead_stride_vnew;
|
|
args.batch_stride_knew = batch_stride_knew;
|
|
args.batch_stride_vnew = batch_stride_vnew;
|
|
}
|
|
else // fmha_fwd_args or fmha_fwd_splitkv_args
|
|
{
|
|
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
|
: bias_buf.GetDeviceBuffer();
|
|
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
|
args.o_ptr = o_buf.GetDeviceBuffer();
|
|
|
|
args.seqstart_q_ptr =
|
|
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
|
args.seqstart_k_ptr =
|
|
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
|
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
|
? seqlen_k_buf.GetDeviceBuffer()
|
|
: nullptr);
|
|
|
|
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
|
args.max_seqlen_q = max_seqlen_q;
|
|
|
|
args.scale_s = scale_s;
|
|
args.scale_p = scale_p;
|
|
args.scale_o = scale_o;
|
|
|
|
args.logits_soft_cap = logits_soft_cap;
|
|
|
|
args.stride_bias =
|
|
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
|
|
args.stride_o = stride_o;
|
|
args.nhead_stride_bias = nhead_stride_bias;
|
|
args.nhead_stride_lse = nhead_stride_lse;
|
|
args.nhead_stride_o = nhead_stride_o;
|
|
args.batch_stride_bias = batch_stride_bias;
|
|
args.batch_stride_lse = batch_stride_lse;
|
|
args.batch_stride_o = batch_stride_o;
|
|
|
|
args.window_size_left = mask.left;
|
|
args.window_size_right = mask.right;
|
|
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
|
|
|
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
|
|
|
args.stride_randval = stride_randval;
|
|
args.nhead_stride_randval = nhead_stride_randval;
|
|
args.batch_stride_randval = batch_stride_randval;
|
|
|
|
args.p_drop = p_drop;
|
|
args.s_randval = s_randval;
|
|
if(drop_prefs)
|
|
{
|
|
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
|
drop_offset_buf.GetDeviceBuffer());
|
|
}
|
|
else
|
|
{
|
|
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
|
}
|
|
|
|
// Group-mode: optional physical padded starts for Q/K
|
|
if(mode == mode_enum::group)
|
|
{
|
|
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
|
|
? nullptr
|
|
: seqstart_q_padded_buf.GetDeviceBuffer());
|
|
args.seqstart_padded_k_ptr =
|
|
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
|
|
}
|
|
|
|
// Batch-mode: optional cumulative effective seqlen overrides
|
|
if(mode == mode_enum::batch)
|
|
{
|
|
args.cu_seqlen_q_ptr = cuq_cum.empty()
|
|
? nullptr
|
|
: reinterpret_cast<const ck_tile::index_t*>(
|
|
cu_seqlen_q_buf.GetDeviceBuffer());
|
|
args.cu_seqlen_kv_ptr = cukv_cum.empty()
|
|
? nullptr
|
|
: reinterpret_cast<const ck_tile::index_t*>(
|
|
cu_seqlen_kv_buf.GetDeviceBuffer());
|
|
}
|
|
}
|
|
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
|
|
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
|
|
|
args.block_table_ptr =
|
|
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
|
args.batch_stride_block_table = batch_stride_block_table;
|
|
args.page_block_size = page_block_size;
|
|
args.is_gappy = false; // use 'false' for flash-attention integration
|
|
|
|
args.cache_batch_idx =
|
|
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
|
|
|
args.num_splits = num_splits;
|
|
|
|
args.stride_o_acc = stride_o_acc;
|
|
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
|
|
args.nhead_stride_o_acc = nhead_stride_o_acc;
|
|
args.batch_stride_lse_acc = batch_stride_lse_acc;
|
|
args.batch_stride_o_acc = batch_stride_o_acc;
|
|
args.split_stride_lse_acc = split_stride_lse_acc;
|
|
args.split_stride_o_acc = split_stride_o_acc;
|
|
}
|
|
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.block_table_ptr =
|
|
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
|
args.batch_stride_block_table = batch_stride_block_table;
|
|
args.page_block_size = page_block_size;
|
|
args.is_gappy = false; // use 'false' for flash-attention integration
|
|
|
|
args.cache_batch_idx =
|
|
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
|
}
|
|
}
|
|
};
|
|
|
|
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
if(need_append_kvcache)
|
|
{
|
|
fmha_fwd_appendkv_traits fwd_appendkv_traits;
|
|
init_traits(fwd_appendkv_traits);
|
|
|
|
fmha_fwd_appendkv_args fwd_appendkv_args;
|
|
init_args(fwd_appendkv_args);
|
|
|
|
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
|
|
}
|
|
#endif
|
|
return 0.0f;
|
|
};
|
|
const float appendkv_ave_time = run_appendkv(stream_config);
|
|
if(appendkv_ave_time < 0.0f)
|
|
{
|
|
std::cout << ", not supported yet" << std::flush << std::endl;
|
|
return fwd_result::no_instance;
|
|
}
|
|
|
|
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
|
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
if(1 == num_splits && use_kvcache)
|
|
{
|
|
fmha_fwd_pagedkv_traits fmha_pagedkv_traits;
|
|
init_traits(fmha_pagedkv_traits);
|
|
|
|
fmha_fwd_pagedkv_args fmha_pagedkv_args;
|
|
init_args(fmha_pagedkv_args);
|
|
|
|
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
|
// If there is no instance for these args, fallback to fmha_fwd_splitkv
|
|
if(ave_time >= 0.0f)
|
|
return ave_time;
|
|
#else
|
|
return ave_time;
|
|
#endif
|
|
}
|
|
#endif // CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
|
if(1 < num_splits || use_kvcache)
|
|
{
|
|
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
|
init_traits(fmha_splitkv_traits);
|
|
|
|
fmha_fwd_splitkv_args fmha_splitkv_args;
|
|
init_args(fmha_splitkv_args);
|
|
|
|
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
|
|
}
|
|
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
|
|
fmha_fwd_traits fmha_traits;
|
|
init_traits(fmha_traits);
|
|
|
|
fmha_fwd_args fmha_args;
|
|
init_args(fmha_args);
|
|
|
|
return fmha_fwd(fmha_traits, fmha_args, sc);
|
|
};
|
|
const float fwd_ave_time = run_fwd(stream_config);
|
|
if(fwd_ave_time < 0.0f)
|
|
{
|
|
std::cout << ", not supported yet" << std::flush << std::endl;
|
|
return fwd_result::no_instance;
|
|
}
|
|
|
|
const float ave_time = appendkv_ave_time + fwd_ave_time;
|
|
const float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
|
const float gb_per_sec = num_byte / 1.E6 / ave_time;
|
|
if(stream_config.time_kernel_)
|
|
{
|
|
std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
|
|
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
|
|
<< gb_per_sec << " GB/s" << std::flush;
|
|
}
|
|
|
|
bool pass = true;
|
|
if(do_validation == 0)
|
|
{
|
|
std::cout << std::flush << std::endl;
|
|
}
|
|
else if(do_validation == 2)
|
|
{
|
|
// NOTE: use gpu to do validation
|
|
ck_tile::naive_attention_fwd_traits naive_t;
|
|
naive_t.q_type = data_type;
|
|
naive_t.k_type = data_type;
|
|
naive_t.v_type = data_type;
|
|
naive_t.o_type = data_type;
|
|
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
|
|
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
|
|
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
|
|
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
|
|
naive_t.variation = 0; // TODO?
|
|
naive_t.quant_algo = 0;
|
|
|
|
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
|
|
|
|
ck_tile::naive_attention_fwd_args naive_a;
|
|
naive_a.q_ptr = q_buf.GetDeviceBuffer();
|
|
naive_a.k_ptr = k_buf.GetDeviceBuffer();
|
|
naive_a.v_ptr = v_buf.GetDeviceBuffer();
|
|
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
|
|
naive_a.scale_s = scale_s;
|
|
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
|
|
naive_a.page_table_ptr =
|
|
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
|
|
naive_a.hdim = hdim_q;
|
|
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
|
|
naive_a.batch_q = batch;
|
|
naive_a.batch_kv = batch;
|
|
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
|
|
naive_a.seqlen_q = seqlen_qs[0];
|
|
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
|
|
naive_a.nhead_q = nhead;
|
|
naive_a.nhead_kv = nhead_k;
|
|
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
|
|
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
|
|
|
|
ck_tile::stream_config naive_s{};
|
|
|
|
naive_attention_fwd(naive_t, naive_a, naive_s);
|
|
|
|
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
|
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
|
|
|
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
|
pass = ck_tile::check_err(
|
|
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
|
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
|
}
|
|
else
|
|
{
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
|
|
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
|
|
if(0 < rotary_dim && stream_config.time_kernel_)
|
|
{
|
|
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
|
|
q_buf.ToDevice(q_host.data());
|
|
run_appendkv(stream_config2);
|
|
run_fwd(stream_config2);
|
|
}
|
|
#endif
|
|
o_buf.FromDevice(o_host.data());
|
|
lse_buf.FromDevice(lse_host.data());
|
|
randval_buf.FromDevice(randval_host.data());
|
|
|
|
constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
|
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
|
|
std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;
|
|
|
|
auto p_compute_element_func = [&]() {
|
|
if constexpr(supports_squant)
|
|
return ck_tile::scales{scale_p};
|
|
else
|
|
return ck_tile::identity{};
|
|
}();
|
|
|
|
auto oacc_element_func = [&]() {
|
|
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_squant)
|
|
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
|
ck_tile::scales{scale_o});
|
|
else if constexpr(supports_squant)
|
|
return ck_tile::scales{scale_o};
|
|
else
|
|
return ck_tile::identity{};
|
|
}();
|
|
|
|
float p_undrop = 1.0 - p_drop;
|
|
uint8_t p_undrop_in_uint8_t =
|
|
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
|
float rp_undrop = 1.0 / p_undrop;
|
|
|
|
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
|
{
|
|
ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
|
ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
|
if(mode == mode_enum::batch)
|
|
{
|
|
if(!cuq_cum.empty())
|
|
{
|
|
real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb];
|
|
}
|
|
if(!cukv_cum.empty())
|
|
{
|
|
real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb];
|
|
}
|
|
}
|
|
|
|
// adjust matrix index according to the mode
|
|
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
|
const ck_tile::index_t cache_b_idx =
|
|
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
|
const ck_tile::index_t query_offset =
|
|
(mode == mode_enum::batch
|
|
? 0
|
|
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
|
|
: seqstart_q_with_padding_host[wb]));
|
|
const ck_tile::index_t key_offset =
|
|
(mode == mode_enum::batch
|
|
? 0
|
|
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
|
|
: seqstart_k_with_padding_host[wb]));
|
|
|
|
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
|
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
|
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
|
|
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
|
|
|
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(
|
|
{nhead, real_seqlen_q, real_seqlen_k});
|
|
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
|
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
|
|
|
|
ck_tile::index_t nr = nhead / nhead_k;
|
|
|
|
// clang-format off
|
|
// permute
|
|
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
|
|
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
|
|
// clang-format on
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
// optionally apply RoPE to the q_host_ref
|
|
if(0 < rotary_dim)
|
|
{
|
|
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
|
|
|
auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
|
|
rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
|
|
|
|
ck_tile::reference_batched_rotary_position_embedding(
|
|
q_host_ref,
|
|
rotary_cos_slice,
|
|
rotary_sin_slice,
|
|
is_rotary_interleaved,
|
|
q_host_ref_ro,
|
|
/*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);
|
|
|
|
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
|
}
|
|
#endif
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
if(0 < page_block_size)
|
|
{
|
|
// clang-format off
|
|
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); });
|
|
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); });
|
|
// clang-format on
|
|
}
|
|
else
|
|
#endif
|
|
{
|
|
// clang-format off
|
|
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
|
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
|
// clang-format on
|
|
}
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
// copy Knew to the end of K
|
|
if(0 < seqlen_knew)
|
|
{
|
|
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
|
// clang-format off
|
|
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
|
|
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
|
|
// clang-format on
|
|
|
|
// optionally apply RoPE to the knew_host_ref
|
|
auto* real_knew_host_ref = &knew_host_ref;
|
|
std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
|
|
if(0 < rotary_dim)
|
|
{
|
|
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
|
|
|
|
auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
|
|
rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
|
|
|
|
ck_tile::reference_batched_rotary_position_embedding(knew_host_ref,
|
|
rotary_cos_slice,
|
|
rotary_sin_slice,
|
|
is_rotary_interleaved,
|
|
knew_host_ref_ro.value());
|
|
|
|
real_knew_host_ref = &knew_host_ref_ro.value();
|
|
}
|
|
|
|
(*real_knew_host_ref).ForEach([&](auto& self, auto i) {
|
|
k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
|
|
});
|
|
}
|
|
#endif
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
if(0 < page_block_size)
|
|
{
|
|
if(is_v_rowmajor)
|
|
{
|
|
// clang-format off
|
|
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); });
|
|
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); });
|
|
// clang-format on
|
|
}
|
|
else
|
|
{
|
|
// clang-format off
|
|
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); });
|
|
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); });
|
|
// clang-format on
|
|
}
|
|
}
|
|
else
|
|
#endif
|
|
{
|
|
if(is_v_rowmajor)
|
|
{
|
|
// clang-format off
|
|
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
|
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
|
|
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
|
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
|
|
// clang-format on
|
|
}
|
|
else
|
|
{
|
|
// clang-format off
|
|
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
|
|
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
|
|
// clang-format on
|
|
}
|
|
}
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
// copy Vnew to the end of V
|
|
if(0 < seqlen_knew)
|
|
{
|
|
ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
|
|
if(is_v_rowmajor)
|
|
{
|
|
// clang-format off
|
|
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
|
|
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
|
|
// clang-format on
|
|
}
|
|
else
|
|
{
|
|
// clang-format off
|
|
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
|
|
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
|
|
// clang-format on
|
|
}
|
|
|
|
vnew_host_ref.ForEach([&](auto& self, auto i) {
|
|
v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
|
|
});
|
|
}
|
|
#endif
|
|
|
|
// reference
|
|
ck_tile::
|
|
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
|
q_host_ref,
|
|
k_host_ref,
|
|
s_host_ref,
|
|
ck_tile::identity{},
|
|
ck_tile::identity{},
|
|
ck_tile::scales(scale_s));
|
|
|
|
if(0.f < logits_soft_cap)
|
|
{
|
|
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
|
|
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
|
|
return ck_tile::type_convert<SaccDataType>(
|
|
logits_soft_cap *
|
|
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
|
|
});
|
|
}
|
|
|
|
if(bias.type == bias_enum::elementwise_bias)
|
|
{
|
|
// elementwise bias
|
|
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
|
// clang-format off
|
|
if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
|
else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
|
// clang-format on
|
|
|
|
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
|
// real_seqlen_k]
|
|
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
|
BiasDataType,
|
|
SMPLComputeDataType,
|
|
SMPLComputeDataType>(
|
|
s_host_ref, bias_host_ref, s_host_ref);
|
|
}
|
|
else if(bias.type == bias_enum::alibi)
|
|
{
|
|
// alibi construct elementwise bias to verify
|
|
auto alibi_host = [&]() {
|
|
if(mask.type != mask_enum::no_mask)
|
|
{
|
|
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
|
|
0,
|
|
mask.left,
|
|
mask.right,
|
|
real_seqlen_q,
|
|
real_seqlen_k,
|
|
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
|
}
|
|
else
|
|
{
|
|
return ck_tile::Alibi<SaccDataType, true>{
|
|
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
|
}
|
|
}();
|
|
|
|
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
|
|
{nhead, real_seqlen_q, real_seqlen_k});
|
|
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
|
for(auto i_h = 0; i_h < nhead; i_h++)
|
|
{
|
|
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
|
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL
|
|
? current_slope
|
|
: -current_slope;
|
|
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
|
{
|
|
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
|
{
|
|
SaccDataType pixel = 0;
|
|
alibi_host.update(pixel, i_r, i_c);
|
|
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
|
}
|
|
}
|
|
}
|
|
// [nhead, real_seqlen_q, real_seqlen_k]
|
|
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
|
SaccDataType,
|
|
SMPLComputeDataType,
|
|
SMPLComputeDataType>(
|
|
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
|
}
|
|
|
|
if(mask.type == mask_enum::no_mask)
|
|
{
|
|
ck_tile::reference_batched_masking<SaccDataType>(
|
|
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
|
}
|
|
else if(mask.type == mask_enum::window_generic)
|
|
{
|
|
ck_tile::reference_batched_masking<SaccDataType>(
|
|
s_host_ref,
|
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
|
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
|
}
|
|
else
|
|
{
|
|
// if left window size is negative, means causal
|
|
// else means generic (for current batch)
|
|
if(mask.left < 0)
|
|
ck_tile::reference_batched_masking<SaccDataType>(
|
|
s_host_ref,
|
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
|
mask.left,
|
|
mask.right,
|
|
real_seqlen_q,
|
|
real_seqlen_k,
|
|
mask.type == mask_enum::mask_top_left));
|
|
else
|
|
ck_tile::reference_batched_masking<SaccDataType>(
|
|
s_host_ref,
|
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
|
mask.left,
|
|
mask.right,
|
|
real_seqlen_q,
|
|
real_seqlen_k,
|
|
mask.type == mask_enum::mask_top_left));
|
|
}
|
|
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
|
if(lse)
|
|
{
|
|
ck_tile::
|
|
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
|
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
|
}
|
|
else
|
|
{
|
|
ck_tile::
|
|
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
|
s_host_ref, p_host_ref, p_compute_element_func);
|
|
}
|
|
|
|
if(p_drop > 0)
|
|
{
|
|
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
|
{nhead, real_seqlen_q, real_seqlen_k});
|
|
ck_tile::reference_batched_dropout_randval(
|
|
randval_host_ref, wb, drop_seed, drop_offset);
|
|
ck_tile::reference_batched_dropout(
|
|
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
|
|
|
ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
|
|
{nhead, real_seqlen_q, real_seqlen_k});
|
|
randval_host_result.ForEach([&](auto& self, const auto& idx) {
|
|
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
|
|
});
|
|
masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) {
|
|
// Ignore all masked values in validation check
|
|
if(std::isinf(self(idx)))
|
|
{
|
|
randval_host_ref(idx) = 0;
|
|
randval_host_result(idx) = 0;
|
|
}
|
|
});
|
|
bool cur_pass = ck_tile::check_err(randval_host_result,
|
|
randval_host_ref,
|
|
"DROPOUT RANDVAL Error: Incorrect results!");
|
|
pass &= cur_pass;
|
|
if(!cur_pass)
|
|
{
|
|
break;
|
|
}
|
|
}
|
|
|
|
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
|
p_host_ref,
|
|
v_host_ref,
|
|
o_host_ref,
|
|
ck_tile::identity{},
|
|
ck_tile::identity{},
|
|
oacc_element_func);
|
|
|
|
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
|
// clang-format off
|
|
// permute
|
|
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
|
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
|
// clang-format on
|
|
|
|
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
|
bool cur_pass = ck_tile::check_err(o_host_result,
|
|
o_host_ref,
|
|
std::string("OUT Error: Incorrect results!"),
|
|
rtol,
|
|
atol);
|
|
pass &= cur_pass;
|
|
if(!cur_pass)
|
|
{
|
|
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
|
|
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
|
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
|
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
|
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
|
|
|
break;
|
|
}
|
|
|
|
if(lse)
|
|
{
|
|
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
|
const ck_tile::index_t query_offset_lse =
|
|
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
|
lse_host_result.ForEach([&](auto& self, auto idx) {
|
|
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
|
|
});
|
|
|
|
cur_pass = ck_tile::check_err(lse_host_result,
|
|
lse_host_ref,
|
|
"LSE Error: Incorrect results!",
|
|
rtol,
|
|
atol,
|
|
/* allow_infinity_ref = */ true);
|
|
|
|
pass &= cur_pass;
|
|
if(!cur_pass)
|
|
{
|
|
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
|
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
|
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
|
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
|
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
|
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
|
}
|
|
|
|
if(json)
|
|
{
|
|
dump_fmha_fwd_json_results(*json,
|
|
data_type,
|
|
mode == mode_enum::batch ? "batch" : "group",
|
|
io_layout(i_perm, o_perm),
|
|
batch,
|
|
nhead,
|
|
nhead_k,
|
|
seqlen_qs[0],
|
|
seqlen_ks[0],
|
|
seqlen_kpads[0],
|
|
hdim_q,
|
|
hdim_v,
|
|
scale_s,
|
|
p_drop,
|
|
lse,
|
|
squant,
|
|
bias.type == bias_enum::elementwise_bias
|
|
? "elementwise_bias"
|
|
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
|
is_v_rowmajor ? "r" : "c",
|
|
pass,
|
|
ave_time,
|
|
tflops,
|
|
gb_per_sec);
|
|
}
|
|
|
|
return pass ? fwd_result::success : fwd_result::failure;
|
|
}
|