mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_TILE] Add LLC-aware FMHA head grouping and head-major scheduling on RDNA (#5018)

## Motivation

Long-sequence FMHA can become memory-bound when K/V working sets exceed Infinity Cache (LLC), causing repeated DRAM traffic across heads. This PR introduces LLC-aware launch-ordering improvements for FMHA forward; it is currently enabled only on gfx11 and gfx12. The approach is inspired by [`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217), adapted to CK's kernel/runner structure and layout handling. In this context, `bshd` is the layout used in Flash-Attention, while `bhsd` is the default layout used by the CK Tile FMHA example.

## Technical Details

This PR adds two complementary strategies:

- For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware head grouping:
  - Estimate LLC size (env override, KFD sysfs, or arch default).
  - Compute group size from K/V bytes per head vs. the LLC target (a sketch follows this description).
  - Launch FMHA forward repeatedly per head group by slicing Q/K/V/O (and related tensors).
- For `bhsd` input layout (`i_perm/o_perm=1`), apply an implicit launch-order adjustment:
  - Keep a single kernel launch.
  - Reinterpret block linearization in `GetTileIndex` to make execution head-major, improving temporal locality of per-head K/V reuse.

Additional integration updates:

- Propagate `num_head_q_total` and `head_start` through FMHA args/kargs.
- Use global head indexing for dropout RNG stream mapping so grouped launches keep deterministic/consistent dropout behavior.
- Keep fallback behavior unchanged when grouping is not beneficial or disabled.

## Test Plan

- `test_ck_tile_fmha`
- `tile_example_fmha_fwd`

## Test Result

- `test_ck_tile_fmha`: all tests passed.
- `tile_example_fmha_fwd`: tested on gfx1100, gfx1151, and gfx1201; all of them show higher performance than the baseline. The improvement is consistent, and performance is well maintained even at long sequence lengths.

    ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

TFLOPs by sequence length (target: gfx1100, layout: bhsd):

| SeqLen | Before | After | Speedup |
| -- | -- | -- | -- |
| 1024 | 56.27 | 61.48 | 1.09x |
| 4096 | 67.10 | 72.27 | 1.08x |
| 8192 | 65.99 | 71.64 | 1.09x |
| 12288 | 61.60 | 76.61 | 1.24x |
| 16384 | 58.99 | 75.74 | 1.28x |
| 20480 | 57.32 | 74.42 | 1.30x |
| 24576 | 56.89 | 74.25 | 1.31x |
| 27280 | 18.93 | 24.48 | 1.29x |

TFLOPs by sequence length (target: gfx1201, layout: bshd):

| SeqLen | Before | After | Speedup |
| -- | -- | -- | -- |
| 1024 | 66.79 | 65.90 | 0.99x |
| 4096 | 85.90 | 86.80 | 1.01x |
| 8192 | 77.06 | 90.29 | 1.17x |
| 12288 | 58.36 | 88.98 | 1.52x |
| 16384 | 52.12 | 88.88 | 1.71x |
| 20480 | 48.11 | 88.42 | 1.84x |
| 24576 | 47.12 | 89.07 | 1.89x |
| 27280 | 49.05 | 50.31 | 1.03x |

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
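Below is a minimal, hypothetical sketch of the group-size computation described under Technical Details. The names (`compute_head_group_size`, `llc_bytes`, `kv_bytes_per_head`) are illustrative and do not mirror the PR's actual helpers; the real implementation additionally handles the env override, the KFD sysfs lookup, and per-arch defaults for the LLC estimate.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical sketch: choose how many heads' worth of K/V fit in the LLC budget.
// kv_bytes_per_head ~= seqlen_k * (hdim_q + hdim_v) * bytes_per_element.
inline int compute_head_group_size(std::int64_t llc_bytes,
                                   std::int64_t kv_bytes_per_head,
                                   int nhead)
{
    if(kv_bytes_per_head <= 0)
        return nhead; // degenerate input: keep a single launch over all heads

    // Target a fraction of the LLC so Q/O and other traffic still fit.
    const std::int64_t budget = llc_bytes * 3 / 4;
    const std::int64_t fit    = budget / kv_bytes_per_head;

    // At least one head per launch, never more than all heads.
    return static_cast<int>(
        std::clamp<std::int64_t>(fit, std::int64_t{1}, std::int64_t{nhead}));
}
```

With a group size `g`, the runner would slice Q/K/V/O and launch FMHA forward once per `[head_start, head_start + g)` range, which is what the `num_head_q_total`/`head_start` arguments threaded through the runner below support.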
2395 lines
110 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "fmha_fwd.hpp"
#include "fmha_fwd_head_grouping.hpp"
#include "utils.hpp"
#include "ck_tile/utility/json_dump.hpp"

#include <array>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <cmath>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
#error "the fmha_fwd_splitkv() api must be enabled in order to cooperate with fmha_fwd_appendkv()"
#endif

enum class fwd_result
{
    success,
    failure,
    invalid_args,
    no_instance,
};

// different error thresholds for different data types
template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdFp32>(std::string /*init_method*/)
{
    double rtol = 1e-5;
    double atol = 1e-5;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
{
    using TypeConfig = FmhaFwdTypeConfig<FmhaFwdFp8>;
    using ODataType = typename TypeConfig::ODataType;
    float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
    double rtol = 0;
    double atol = 16 * (o_dtype_max > 240 ? 2 : 1);
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1.8e-1;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1.8e-1;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdMxFp8>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1.8e-1;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdMxFp4>(std::string /*init_method*/)
{
    double rtol = 1e-1;
    double atol = 2.6e-1;
    return ck_tile::make_tuple(rtol, atol);
}

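// Illustrative note (not from the original source): the heuristic below scores each
// split count by wave quantization efficiency, eff = n_waves / ceil(n_waves). For
// example, 24 work units on 80 CUs give eff = 24/80 = 0.30 with one split, but
// 72/80 = 0.90 with three splits, so small grids favor more splits; the first split
// count within 85% of the best efficiency is chosen.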
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
    // If we have enough to almost fill the SMs, then just use 1 split
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }
    max_splits = std::min({max_splits, num_SMs});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
        float eff = n_waves / ceil(n_waves);
        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
        if(eff > max_efficiency)
        {
            max_efficiency = eff;
        }
        efficiency.push_back(eff);
    }
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
        {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

int override_num_splits_if_necessary(
    int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
    (void)hdim_v;
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
    {
        return num_splits;
    }

    hipDeviceProp_t props{};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
    {
        return num_splits;
    }

    // tile size; must match the value used in generate.py
    const int kM0 = 64;

    const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);

    if(num_splits < 1 && p_drop == 0.0f)
    {
        return num_splits_heuristic(
            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
    }

    return num_splits;
}
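
// Usage note (illustrative numbers, not from the original source): the caller passes
// props.multiProcessorCount * 2 as the occupancy target, so e.g. batch=2, nhead=8,
// max_seqlen_q=1024 gives num_m_blocks = ceil(1024/64) = 16 and 2 * 8 * 16 = 256 work
// units, which on a 40-CU device (target 80) already exceeds 0.8 * 80, keeping one split.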

template <typename SMPLComputeDataType>
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
                                     const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
                                     ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
                                     ck_tile::index_t nhead,
                                     ck_tile::index_t real_seqlen_q,
                                     ck_tile::index_t real_seqlen_k)
{
    for(auto i_h = 0; i_h < nhead; i_h++)
    {
        for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
        {
            for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
            {
                s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
            }
            // Append sink token at the end of each row
            s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
        }
    }
}

template <typename TypeConfig, bool IsMx>
struct ScalesConfig
{
    using QScaleDataType = float;
    using KScaleDataType = float;
    using VScaleDataType = float;

    static constexpr ck_tile::index_t kQKScaleGranularity = 1;
    static constexpr ck_tile::index_t kVScaleGranularity = 1;
};

template <typename TypeConfig>
struct ScalesConfig<TypeConfig, true>
{
    using QScaleDataType = typename TypeConfig::QScaleDataType;
    using KScaleDataType = typename TypeConfig::KScaleDataType;
    using VScaleDataType = typename TypeConfig::VScaleDataType;

    static constexpr ck_tile::index_t kQKScaleGranularity = TypeConfig::kQKScaleGranularity;
    static constexpr ck_tile::index_t kVScaleGranularity = TypeConfig::kVScaleGranularity;
};
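
// Note (illustrative, not from the original source): for MX data types the scale
// granularities come from the type config (e.g. one e8m0 scale per microscaling block),
// so hdim_q_scale = ceil(hdim_q / kQKScaleGranularity) below; for non-MX types both
// granularities are 1 and the descale tensors degenerate to plain float scalars.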

template <typename DataTypeConfig>
fwd_result fmha_fwd_run(mode_enum mode,
                        ck_tile::index_t batch,
                        ck_tile::index_t nhead,
                        ck_tile::index_t nhead_k,
                        std::vector<ck_tile::index_t> seqlen_qs,
                        std::vector<ck_tile::index_t> seqlen_ks,
                        ck_tile::index_t hdim_q,
                        ck_tile::index_t hdim_v,
                        ck_tile::index_t seqlen_knew,
                        std::vector<ck_tile::index_t> seqlen_qpads,
                        std::vector<ck_tile::index_t> seqlen_kpads,
                        std::vector<ck_tile::index_t> q_eff_lens_per_batch,
                        std::vector<ck_tile::index_t> kv_eff_lens_per_batch,
                        ck_tile::index_t rotary_dim,
                        bool i_perm,
                        bool o_perm,
                        float scale_s,
                        float logits_soft_cap,
                        bool is_v_rowmajor,
                        bool lse,
                        ck_tile::index_t page_block_size,
                        bool use_cache_batch_idx,
                        std::string bias_str,
                        float p_drop,
                        uint64_t drop_seed,
                        uint64_t drop_offset,
                        bool drop_prefs,
                        std::string mask_str,
                        std::string qscale_str,
                        bool is_rotary_interleaved,
                        ck_tile::index_t num_splits,
                        std::string init_method,
                        uint32_t seed,
                        int do_validation,
                        int init_sink_value,
                        const ck_tile::stream_config& stream_config,
                        std::optional<std::string> json = std::nullopt)
{
    using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;

    constexpr bool is_mx = ck_tile::is_any_of<DataTypeConfig, FmhaFwdMxFp8, FmhaFwdMxFp4>::value;

    using QDataType = typename TypeConfig::QDataType;
    using KDataType = typename TypeConfig::KDataType;
    using VDataType = typename TypeConfig::VDataType;
    using BiasDataType = typename TypeConfig::BiasDataType;
    using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
    using LSEDataType = typename TypeConfig::LSEDataType;
    using SaccDataType = typename TypeConfig::SaccDataType;
    using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
    using PDataType = std::conditional_t<is_mx, float, typename TypeConfig::PDataType>;
    using OaccDataType = typename TypeConfig::OaccDataType;
    using ODataType = typename TypeConfig::ODataType;

    using QScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::QScaleDataType;
    using KScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::KScaleDataType;
    using VScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::VScaleDataType;

    constexpr ck_tile::index_t kQKScaleGranularity =
        ScalesConfig<TypeConfig, is_mx>::kQKScaleGranularity;
    constexpr ck_tile::index_t kVScaleGranularity =
        ScalesConfig<TypeConfig, is_mx>::kVScaleGranularity;

    // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
    // compute block size
    constexpr ck_tile::index_t block_scale_size_q_ = 128;
    constexpr ck_tile::index_t block_scale_size_kv_ = 128;

    const std::string data_type = []() {
        if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
            return "fp32";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
            return "fp16";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf16>)
            return "bf16";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
            return "fp8";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf8>)
            return "bf8";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
            return "fp8bf16";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
            return "fp8fp32";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp8>)
            return "mxfp8";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp4>)
            return "mxfp4";
        else
            static_assert(false);
    }();

    if(nhead_k < 0)
        nhead_k = nhead;
    if(nhead % nhead_k != 0)
    {
        std::cerr << "nhead:" << nhead << " must be a multiple of nhead_k:" << nhead_k
                  << std::endl;
        return fwd_result::invalid_args;
    }

    if(hdim_q % ck_tile::numeric_traits<QDataType>::PackedSize != 0)
    {
        std::cerr << "hdim_q is not a multiple of the packed size of the Q data type; rounding up"
                  << std::endl;
        hdim_q =
            ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<QDataType>::PackedSize);
    }
    if(hdim_q % ck_tile::numeric_traits<KDataType>::PackedSize != 0)
    {
        std::cerr << "hdim_q is not a multiple of the packed size of the K data type; rounding up"
                  << std::endl;
        hdim_q =
            ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<KDataType>::PackedSize);
    }
    if(is_mx && !seqlen_kpads.empty() && seqlen_kpads[0] > 0)
    {
        std::cerr
            << "seqlen_kpads is not supported with MX types. ignoring the 'seqlen_kpads' option"
            << std::endl;
        seqlen_kpads = {-1};
    }

    std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}());
    auto next_seed = [&random_engine]() { return static_cast<unsigned int>(random_engine()); };

    if(hdim_v < 0)
        hdim_v = hdim_q;

#if !CK_TILE_FMHA_FWD_APPENDKV_API
    if(seqlen_knew != 0)
    {
        std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
                  << std::endl;
        seqlen_knew = 0;
    }
#endif
    if(seqlen_knew < 0)
    {
        seqlen_knew = randint<ck_tile::index_t>(1, seqlen_qs[0], random_engine);
    }

    if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
                   std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
    {
        if(0 < rotary_dim)
        {
            std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
            return fwd_result::invalid_args;
        }
    }
#if !CK_TILE_FMHA_FWD_APPENDKV_API
    else if(0 < rotary_dim)
    {
        std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
                  << std::endl;
        rotary_dim = 0;
    }
#endif
    // to use fmha_fwd_appendkv(), make sure it's in batch mode
    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
    if(need_append_kvcache && mode == mode_enum::group)
    {
        std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
        mode = mode_enum::batch;
    }
    if(!(rotary_dim <= hdim_q))
    {
        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
        return fwd_result::invalid_args;
    }
    else if(!(rotary_dim % 16 == 0))
    {
        std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
        return fwd_result::invalid_args;
    }

#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
      CK_TILE_FMHA_FWD_PAGEDKV_API))
    if(0 < page_block_size)
    {
        std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
                  << std::endl;
        page_block_size = 0;
    }
#endif
    if(!(page_block_size % 128 == 0))
    {
        std::cerr << "only paged-kvcache block sizes divisible by 128 are currently supported"
                  << std::endl;
        return fwd_result::invalid_args;
    }

#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API)
    if(use_cache_batch_idx)
    {
        std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
                  << std::endl;
        use_cache_batch_idx = false;
    }
#else
    if(use_cache_batch_idx)
    {
        if(0 < page_block_size)
        {
            std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
                         "'cache_batch_idx' option"
                      << std::endl;
            use_cache_batch_idx = false;
        }
        else if(mode == mode_enum::group)
        {
            std::cerr << "group mode will not use cache_batch_idx. ignoring the "
                         "'cache_batch_idx' option"
                      << std::endl;
            use_cache_batch_idx = false;
        }
    }
#endif
    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);

    // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
    const bool has_group_q_padding =
        mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0);
    const bool has_group_k_padding =
        mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0);
    const bool has_group_padding = has_group_q_padding || has_group_k_padding;
    const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty();
    const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty();
    const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding;
    const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
    const bool using_pagedkv = (0 < page_block_size);
    const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
    if((using_appendkv || using_pagedkv || using_splitkv) &&
       (has_group_padding || has_batch_padding))
    {
        std::cerr << "Padding (physical or effective lengths) is not supported with "
                     "appendkv/splitkv/pagedkv pipelines"
                  << std::endl;
        return fwd_result::invalid_args;
    }

    std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
        generate_missing_seqlens(mode,
                                 batch,
                                 seqlen_qs,
                                 seqlen_ks,
                                 seqlen_qpads,
                                 seqlen_kpads,
                                 /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
                                 need_append_kvcache,
                                 random_engine);

    if(ck_tile::numeric_traits<VDataType>::PackedSize != 0)
    {
        // Ensure that all seqlens are even if V has packed data type
        for(auto& s : seqlen_ks)
        {
            s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
        }
        for(auto& s : kv_eff_lens_per_batch)
        {
            s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
        }
    }

    for(ck_tile::index_t wb = 0; wb < batch; ++wb)
    {
        if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb])
        {
            std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
            return fwd_result::invalid_args;
        }
        if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
        {
            std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
            return fwd_result::invalid_args;
        }
    }

    // compute kvcache seqlen_k (before appending knew/vnew)
    auto cache_seqlen_ks = seqlen_ks;
    std::transform(cache_seqlen_ks.begin(),
                   cache_seqlen_ks.end(),
                   cache_seqlen_ks.begin(),
                   [&](auto seqlen_k) { return seqlen_k - seqlen_knew; });

#if 0
    std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
    std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
    std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
    std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
    std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
#endif

    if(scale_s == .0f)
        scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?

    bias_info bias = bias_info::decode(bias_str);

    mask_info mask =
        mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore

    quant_scale_info qscale = quant_scale_info::decode(qscale_str);

    if(is_mx && qscale.type != quant_scale_enum::mx)
    {
        std::cerr << "The value of qscale_str must be 'mx' for MX data types" << std::endl;
        return fwd_result::invalid_args;
    }
    else if(!is_mx && qscale.type == quant_scale_enum::mx)
    {
        std::cerr << "The value of qscale_str cannot be 'mx' for non-MX data types" << std::endl;
        return fwd_result::invalid_args;
    }
    if(is_mx && is_v_rowmajor)
    {
        std::cerr << "The value of is_v_rowmajor must be 'false' for MX data types" << std::endl;
        return fwd_result::invalid_args;
    }

    if(p_drop < 0.0f || p_drop > 1.0f)
    {
        std::cerr << "The value of p_drop must be in the range [0, 1]" << std::endl;
        return fwd_result::invalid_args;
    }

    bool s_randval = false;
    if(p_drop > 0.0f && do_validation)
    {
        s_randval = true;
    }

#if !CK_TILE_FMHA_FWD_SPLITKV_API
    if(num_splits != 1)
    {
        std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
        num_splits = 1;
    }
#endif

    const auto seqstart_q_host = to_seqstarts(seqlen_qs);
    const auto seqstart_k_host = to_seqstarts(seqlen_ks);
    const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads);
    const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);

    // Optional batch-mode cumulative seqlen overrides
    std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
    if(mode == mode_enum::batch)
    {
        auto calculate_cumulative = [&](std::vector<ck_tile::index_t>& per_batch_vec,
                                        std::vector<ck_tile::index_t>& cum_vec) {
            if(!per_batch_vec.empty() && per_batch_vec[0] != -1)
            {
                if(per_batch_vec.size() < static_cast<size_t>(batch))
                {
                    per_batch_vec.resize(batch, per_batch_vec.back());
                }
                cum_vec.resize(batch + 1);
                cum_vec[0] = 0;
                for(int i = 0; i < batch; ++i)
                    cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
            }
        };

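        // Example: per-batch effective lengths {3, 5, 2} yield the cumulative vector
        // {0, 3, 8, 10}; cum_vec[i + 1] - cum_vec[i] recovers each batch's length.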
        calculate_cumulative(q_eff_lens_per_batch, cuq_cum);
        calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
    }

    // accumulation numbers for performance evaluation
    std::size_t flop = 0, num_byte = 0;
    auto max_seqlen_q =
        std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
    int32_t i_block_scale_q = 0;
    int32_t i_block_scale_k = 0;
    int32_t i_seqstart_v_scale = 0;
    std::vector<int32_t> block_scale_seqstart_q_host = {0};
    std::vector<int32_t> block_scale_seqstart_k_host = {0};
    std::vector<int32_t> seqstart_v_scale_host = {0};
    auto max_seqlen_k = std::numeric_limits<int32_t>::min();
    {
        for(ck_tile::index_t wb = 0; wb < batch; ++wb)
        {
            const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
            const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];

            if(max_seqlen_q < real_seqlen_q)
            {
                max_seqlen_q = real_seqlen_q;
            }

            if(max_seqlen_k < real_seqlen_k)
            {
                max_seqlen_k = real_seqlen_k;
            }
            if(qscale.type == quant_scale_enum::blockscale)
            {
                i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
                i_block_scale_k +=
                    ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
                block_scale_seqstart_q_host.push_back(i_block_scale_q);
                block_scale_seqstart_k_host.push_back(i_block_scale_k);
            }
            else if(qscale.type == quant_scale_enum::mx)
            {
                i_seqstart_v_scale +=
                    ck_tile::integer_divide_ceil(real_seqlen_k, kVScaleGranularity);
                seqstart_v_scale_host.push_back(i_seqstart_v_scale);
            }

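            // FLOP model: 2 * unmasked_area * hdim_q for Q*K^T plus
            // 2 * unmasked_area * hdim_v for P*V, summed over all query heads.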
            flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
                             static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);

            num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q /
                                     ck_tile::numeric_traits<QDataType>::PackedSize +
                                 sizeof(ODataType) * real_seqlen_q * hdim_v);
            num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q /
                                       ck_tile::numeric_traits<KDataType>::PackedSize +
                                   sizeof(VDataType) * hdim_v * real_seqlen_k /
                                       ck_tile::numeric_traits<VDataType>::PackedSize);
        }
    }

    const ck_tile::index_t max_num_page_blocks =
        (0 < page_block_size
             ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
             : 0);

    // legalize num_splits according to other options
    if(num_splits < 1)
    {
        num_splits = override_num_splits_if_necessary(
            batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
    }
    if(128 < num_splits)
    {
        std::cerr << "num_splits greater than 128 is not supported" << std::endl;
        return fwd_result::invalid_args;
    }
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
    if(0 < p_drop && (1 < num_splits || use_kvcache))
    {
        std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option"
                  << std::endl;
        p_drop = 0.0f;
    }
#endif

    static const auto get_lengths = [](bool permute,
                                       ck_tile::index_t b /*batch*/,
                                       ck_tile::index_t h /*nhead*/,
                                       ck_tile::index_t s /*seqlen*/,
                                       ck_tile::index_t d /*hdim*/) {
        if(permute)
            return std::array<ck_tile::index_t, 4>{b, h, s, d};
        else
            return std::array<ck_tile::index_t, 4>{b, s, h, d};
    };
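
    // Layout mapping (as in the PR description): permute=true selects bhsd {b, h, s, d};
    // permute=false selects bshd {b, s, h, d}. E.g. q_host with i_perm=0 is laid out
    // [batch, seqlen_q, nhead, hdim_q].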

    // host memory for storing all the tensor elements
    const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
    // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
    const ck_tile::index_t shape_seqlen_q =
        (mode == mode_enum::batch ? seqlen_qs[0]
                                  : (has_group_q_padding && !seqstart_q_with_padding_host.empty()
                                         ? seqstart_q_with_padding_host.back()
                                         : seqstart_q_host.back()));
    const ck_tile::index_t shape_seqlen_k =
        (mode == mode_enum::batch ? seqlen_ks[0]
                                  : (has_group_k_padding && !seqstart_k_with_padding_host.empty()
                                         ? seqstart_k_with_padding_host.back()
                                         : seqstart_k_host.back()));

    const ck_tile::index_t num_block_scale_q =
        (mode == mode_enum::batch)
            ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_)
            : i_block_scale_q;
    const ck_tile::index_t num_block_scale_kv =
        (mode == mode_enum::batch)
            ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_)
            : i_block_scale_k;

    ck_tile::HostTensor<QDataType> q_host(
        get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
    ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
    ck_tile::HostTensor<KDataType> k_host(
        0 < page_block_size
            ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
            : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
    /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
    ck_tile::HostTensor<KDataType> knew_host(
        0 < seqlen_knew
            ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
    ck_tile::HostTensor<VDataType> v_host(
        0 < page_block_size
            ? (is_v_rowmajor
                   ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
                   : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
            : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
                             : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
    ck_tile::HostTensor<VDataType> vnew_host(
        0 < seqlen_knew
            ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
                             : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
    ck_tile::HostTensor<BiasDataType> bias_host(
        bias.type == bias_enum::elementwise_bias
            ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);

    ck_tile::HostTensor<SaccDataType> alibi_slope_host(
        bias.type == bias_enum::alibi
            ? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
                                   : std::array<ck_tile::index_t, 2>{batch, nhead})
            : std::array<ck_tile::index_t, 2>{1, 1});

    auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
        std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, next_seed());

    ck_tile::HostTensor<LSEDataType> lse_acc_host(
        1 < num_splits || use_kvcache
            ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
    ck_tile::HostTensor<OaccDataType> o_acc_host(
        1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
                                                                        nhead,
                                                                        num_splits,
                                                                        shape_seqlen_q,
                                                                        hdim_v}
                                      : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});

    const ck_tile::index_t hdim_q_scale = ck_tile::integer_divide_ceil(hdim_q, kQKScaleGranularity);
    const ck_tile::index_t shape_seqlen_v_scale = seqstart_v_scale_host.back();

    ck_tile::HostTensor<QScaleDataType> q_descale_host({1});
    ck_tile::HostTensor<KScaleDataType> k_descale_host({1});
    ck_tile::HostTensor<VScaleDataType> v_descale_host({1});
    if constexpr(is_mx)
    {
        q_descale_host = ck_tile::HostTensor<QScaleDataType>(
            get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q_scale));
        k_descale_host = ck_tile::HostTensor<KScaleDataType>(
            get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q_scale));
        v_descale_host = ck_tile::HostTensor<VScaleDataType>(
            get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_v_scale));
    }
    else if(qscale.type == quant_scale_enum::blockscale)
    {
        q_descale_host = ck_tile::HostTensor<QScaleDataType>(
            std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q});
        k_descale_host = ck_tile::HostTensor<KScaleDataType>(
            std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
        v_descale_host = ck_tile::HostTensor<VScaleDataType>(
            std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
    }

    // batch mode of lse data layout is [batch, nhead, seqlen_q]
    // group mode of lse data layout is [nhead, total_seqlen_q]
    ck_tile::HostTensor<LSEDataType> lse_host(
        lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
            : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);

    ck_tile::HostTensor<ODataType> o_host(
        get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));

    ck_tile::HostTensor<RandValOutputDataType> randval_host(
        p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});

    ck_tile::HostTensor<int32_t> block_table_host(
        0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
                            : std::array<ck_tile::index_t, 2>{1, 1});

    ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
                                                          ? std::array<ck_tile::index_t, 1>{batch}
                                                          : std::array<ck_tile::index_t, 1>{1});
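
    // Initialization methods: "ui"/"0" = integer uniform, "ni" = integer normal,
    // "uf"/"1" = uniform float, "nf" = normal float, "tf"/"2" = trig pattern,
    // "3" = values spanning the full dtype range.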
    if(init_method == "ui" || init_method == "0")
    {
        ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(k_host);
        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(knew_host);
        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(v_host);
        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(vnew_host);
        ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
            bias_host);
    }
    else if(init_method == "ni")
    {
        ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(k_host);
        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(knew_host);
        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(v_host);
        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(vnew_host);
        ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
            bias_host);
    }
    else if(init_method == "uf" || init_method == "1")
    {
        ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, next_seed()}(q_host);
        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(k_host);
        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(knew_host);
        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(v_host);
        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(vnew_host);
        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
    }
    else if(init_method == "nf")
    {
        ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, next_seed()}(q_host);
        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(k_host);
        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(knew_host);
        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(v_host);
        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(vnew_host);
        ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, next_seed()}(bias_host);
    }
    else if(init_method == "tf" || init_method == "2")
    {
        ck_tile::FillTrigValue<QDataType>{}(q_host);
        ck_tile::FillTrigValue<KDataType>{}(k_host);
        ck_tile::FillTrigValue<KDataType>{}(knew_host);
        ck_tile::FillTrigValue<VDataType>{}(v_host);
        ck_tile::FillTrigValue<VDataType>{}(vnew_host);
        ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
    }
    else if(init_method == "3")
    {
        float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
        float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
        float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());

        ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(
            knew_host);
        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(
            vnew_host);
        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
    }
    if(bias.type == bias_enum::alibi)
    {
        auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
        assert(slopes.size() == static_cast<std::size_t>(nhead));
        if(bias.rank_info == 0)
        {
            // alibi in 1*h
            std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
        }
        else
        {
            // alibi in b*h
            for(auto i_b = 0; i_b < batch; i_b++)
            {
                std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
            }
        }
    }
    if constexpr(is_mx)
    {
        auto gen_scales = [&](auto& scales, auto data, float range) {
            using DataType = decltype(data);
            using ScaleType = ck_tile::remove_cvref_t<decltype(*scales.begin())>;
            if constexpr(std::is_same_v<ScaleType, ck_tile::e8m0_t>)
            {
                const float base =
                    -std::log2(ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max()));
                // e8m0_t is basically an exponent of float32
                // When scales are applied to tensor values, value * exp2(base - range) is around
                // 0.125 and value * exp2(base + range) is around 8 for all types (fp8/bf8/fp4)
                ck_tile::HostTensor<float> pow2(scales.get_lengths());
                ck_tile::FillUniformDistributionIntegerValue<float>{
                    base - range, base + range, next_seed()}(pow2);
                scales.ForEach([&](auto& self, const auto& i) {
                    self(i) = ck_tile::type_convert<ScaleType>(std::exp2(pow2(i)));
                });
            }
            else
            {
                static_assert(false);
            }
        };
        gen_scales(q_descale_host, QDataType{}, 3);
        gen_scales(k_descale_host, KDataType{}, 3);
        // When P is fp4, only 8 magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6) are used to
        // quantize P. Excessively large V values can create rare error outliers between
        // the host reference (no quantization) and the device ("running" FA softmax +
        // quantization), so we reduce the max value by using a smaller range of V scales.
        gen_scales(v_descale_host,
                   VDataType{},
                   std::is_same_v<typename TypeConfig::PDataType, ck_tile::pk_fp4_t> ? 1 : 3);
    }
    else if(qscale.type == quant_scale_enum::pertensor)
    {
        float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
        float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
        float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());

        float qkv_max = 3.f;
        q_descale_host(0) = qkv_max / q_dtype_max;
        k_descale_host(0) = qkv_max / k_dtype_max;
        v_descale_host(0) = qkv_max / v_dtype_max;
    }
    else if(qscale.type == quant_scale_enum::blockscale)
    {
        float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
        float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
        float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());

        float qkv_max = 3.f;
        float max_descale_q = qkv_max / q_dtype_max;
        float max_descale_k = qkv_max / k_dtype_max;
        float max_descale_v = qkv_max / v_dtype_max;

        ck_tile::FillUniformDistribution<float>{max_descale_q * 0.8f, max_descale_q, next_seed()}(
            q_descale_host);
        ck_tile::FillUniformDistribution<float>{max_descale_k * 0.8f, max_descale_k, next_seed()}(
            k_descale_host);
        ck_tile::FillUniformDistribution<float>{max_descale_v * 0.8f, max_descale_v, next_seed()}(
            v_descale_host);
    }

    iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
    iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
    if(init_sink_value != 0)
    {
        // sink is initialized to fixed integer values for easy debugging; the 30-60
        // range keeps them close to typical rowmax values.
        ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
            sink_host);
    }
    ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() *
                                                  sizeof(int32_t));
    ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
                                                  sizeof(int32_t));
    ck_tile::DeviceMem scale_seqstart_v_buf(seqstart_v_scale_host.size() * sizeof(int32_t));
    ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem seqstart_q_buf(seqstart_q_host.size() * sizeof(int32_t));
    ck_tile::DeviceMem seqstart_k_buf(seqstart_k_host.size() * sizeof(int32_t));
    ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
                                                 ? 0
                                                 : seqstart_q_with_padding_host.size() *
                                                       sizeof(int32_t));
    ck_tile::DeviceMem seqstart_k_padded_buf(
        seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
    // Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding
    // enabled)
    ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
    // Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
    // kvcache or group mode with padding enabled)
    ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
                                        ? seqlen_ks.size() * sizeof(int32_t)
                                        : 0);
    ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
                                                       : cuq_cum.size() * sizeof(ck_tile::index_t));
    ck_tile::DeviceMem cu_seqlen_kv_buf(
        cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
    ck_tile::DeviceMem cache_seqlen_k_buf(
        need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
    ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
    ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
    ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());

    q_buf.ToDevice(q_host.data());
    k_buf.ToDevice(k_host.data());
    v_buf.ToDevice(v_host.data());
    sink_buf.ToDevice(sink_host.data());
    knew_buf.ToDevice(knew_host.data());
    vnew_buf.ToDevice(vnew_host.data());
    bias_buf.ToDevice(bias_host.data());
    q_descale_buf.ToDevice(q_descale_host.data());
    k_descale_buf.ToDevice(k_descale_host.data());
    v_descale_buf.ToDevice(v_descale_host.data());
    block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
    block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
    scale_seqstart_v_buf.ToDevice(seqstart_v_scale_host.data());
    seqstart_q_buf.ToDevice(seqstart_q_host.data());
    // Keep logical starts in seqstart_k_buf; pass padded K via separate pointer
    seqstart_k_buf.ToDevice(seqstart_k_host.data());
    seqstart_q_padded_buf.ToDevice(
        seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
    seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
                                                       : seqstart_k_with_padding_host.data());
    cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
    cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
    seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
    seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
                              ? seqlen_ks.data()
                              : nullptr);
    cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
    rotary_cos_buf.ToDevice(rotary_cos_host.data());
    rotary_sin_buf.ToDevice(rotary_sin_host.data());
    drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
    drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
    alibi_slope_buf.ToDevice(alibi_slope_host.data());
    block_table_buf.ToDevice(block_table_host.data());
    cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());

    // clang-format off
    auto layout_str = [&](bool permute){
        if(permute) return std::string("bhsd");
        else return std::string("bshd");
    };
    auto io_layout = [&](bool iperm_, bool operm_) {
        if(iperm_ == operm_) return layout_str(iperm_);
        else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
    };
    // clang-format on

    std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
              << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
              << "/" << seqlen_ks[0]
              << (seqlen_kpads[0] < 0 ? ""
                                      : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
              << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
              << ", p_drop:" << p_drop << ", lse:" << lse << ", qscale:" << qscale
              << ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c");
#if CK_TILE_FMHA_FWD_APPENDKV_API
    if(0 < rotary_dim)
    {
        std::cout << ", rotary_dim:" << rotary_dim << "("
                  << (is_rotary_interleaved ? "inter" : "half") << ")";
    }
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
    if(1 < num_splits)
    {
        std::cout << ", num_splits:" << num_splits;
    }
    if(0 < page_block_size)
    {
        std::cout << ", page_block_size:" << page_block_size;
    }
    if(use_cache_batch_idx)
    {
        std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
    }
#endif
    // Padding / effective length diagnostic logging
    auto print_vec = [&](const char* label, const std::vector<int>& v) {
        if(v.empty())
            return;
        std::cout << ", " << label << ":[";
        for(std::size_t i = 0; i < v.size(); ++i)
        {
            if(i)
                std::cout << ",";
            std::cout << v[i];
        }
        std::cout << "]";
    };

    if(has_group_padding)
    {
        bool has_qpad = !seqstart_q_with_padding_host.empty();
        bool has_kpad = (seqlen_kpads[0] >= 0);
        if(has_qpad)
        {
            print_vec("q_logical", seqlen_qs);
            print_vec("q_padded", seqlen_qpads);
        }
        if(has_kpad)
        {
            print_vec("k_logical", seqlen_ks);
            print_vec("k_padded", seqlen_kpads);
        }
    }
    else if(has_batch_padding)
    {
        // derive effective lengths from cumulative arrays if present
        if(!cuq_cum.empty())
        {
            std::vector<int> eff_q(batch);
            for(int b_i = 0; b_i < batch; ++b_i)
                eff_q[b_i] = static_cast<int>(cuq_cum[b_i + 1] - cuq_cum[b_i]);
            print_vec("q_eff", eff_q);
        }
        if(!cukv_cum.empty())
        {
            std::vector<int> eff_kv(batch);
            for(int b_i = 0; b_i < batch; ++b_i)
                eff_kv[b_i] = static_cast<int>(cukv_cum[b_i + 1] - cukv_cum[b_i]);
            print_vec("kv_eff", eff_kv);
        }
    }

    std::cout << std::flush;

    const auto init_traits = [&](auto& traits) {
        traits.hdim_q = hdim_q;
        traits.hdim_v = hdim_v;
        traits.data_type = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;

        if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
        {
            traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
                                                                        : rope_enum::half_rotated)
                                               : rope_enum::none);
        }
        else // fmha_fwd_traits or fmha_splitkv_traits
        {
            traits.is_group_mode = (mode == mode_enum::group);
            traits.has_logits_soft_cap = 0.f < logits_soft_cap;
            traits.mask_type = mask.type;
            traits.bias_type = bias.type;
            traits.has_sink = (mask.sink > 0);
            traits.has_lse = lse;

            if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
            {
                traits.has_dropout = (p_drop > 0.0f);
                traits.qscale_type = qscale.type;
            }
            else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
                                             std::decay_t<decltype(traits)>>)
            {
                traits.use_pagedkv = (0 < page_block_size);
            }
        }
    };

    const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
        /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
        /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
        /// 'nhead_stride_bias' are 0.
        // setup stride_* arguments
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
                                           : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_vnew = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
        }();
        const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k);
        const ck_tile::index_t stride_randval = (max_seqlen_k);
        const ck_tile::index_t stride_o_acc = (hdim_v);
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
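        // Convention note: each stride_* above is the distance between consecutive
        // sequence positions, so bhsd (perm=1) rows are hdim elements apart while
        // bshd (perm=0) rows are nhead * hdim apart (heads interleaved per token).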
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k =
            (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
                                 : (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
        const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
                                           : (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
            else
                return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
                                           : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
        }();
        const ck_tile::index_t nhead_stride_vnew = [&]() {
            if(is_v_rowmajor)
                return i_perm ? seqlen_knew * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
        }();
        const ck_tile::index_t nhead_stride_bias =
            (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
        const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
        const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
        const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
        const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q;
        const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv;
        const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv;
        // setup batch_stride_* arguments
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k =
            (0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
                                 : (nhead_k * shape_seqlen_k * hdim_q));
        const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
        const ck_tile::index_t batch_stride_v =
            (0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
                                 : (nhead_k * hdim_v * shape_seqlen_k));
        const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
        const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
        const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
        const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
        const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
        const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
        const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
        const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead;
        const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k;
        const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k;
        // setup split_stride_* arguments (only used in split-kv kernel)
        const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
        const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);

        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.v_ptr = v_buf.GetDeviceBuffer();
        if(init_sink_value != 0)
            args.sink_ptr = sink_buf.GetDeviceBuffer();
        else
            args.sink_ptr = nullptr;
        args.batch = batch;
        args.seqlen_q = shape_seqlen_q; // unused in group mode
        args.hdim_q = hdim_q;
        args.hdim_v = hdim_v;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;
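        // Head-grouping hooks (see PR description): a grouped launch computes heads
        // [head_start, head_start + nhead_q) out of num_head_q_total; this example runs
        // the full head range in a single launch, and global head indexing keeps the
        // dropout RNG stream mapping consistent across grouped launches.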
|
|
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.num_head_q_total = nhead;
|
|
args.head_start = 0;
|
|
}
|
|
|
|
args.stride_q = stride_q;
|
|
args.stride_k = stride_k;
|
|
args.stride_v = stride_v;
|
|
args.nhead_stride_q = nhead_stride_q;
|
|
args.nhead_stride_k = nhead_stride_k;
|
|
args.nhead_stride_v = nhead_stride_v;
|
|
args.batch_stride_q = batch_stride_q;
|
|
args.batch_stride_k = batch_stride_k;
|
|
args.batch_stride_v = batch_stride_v;
|
|
|
|
if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
|
|
{
|
|
args.knew_ptr = knew_buf.GetDeviceBuffer();
|
|
args.vnew_ptr = vnew_buf.GetDeviceBuffer();
|
|
args.seqlen_knew = seqlen_knew;
|
|
|
|
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
|
|
|
|
args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
|
|
args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
|
|
args.rotary_dim = rotary_dim;
|
|
args.has_mask = (mask.type != mask_enum::no_mask);
|
|
|
|
args.block_table_ptr =
|
|
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
|
args.batch_stride_block_table = batch_stride_block_table;
|
|
args.page_block_size = page_block_size;
|
|
|
|
args.cache_batch_idx =
|
|
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
|
|
|
args.stride_knew = stride_knew;
|
|
args.stride_vnew = stride_vnew;
|
|
args.nhead_stride_knew = nhead_stride_knew;
|
|
args.nhead_stride_vnew = nhead_stride_vnew;
|
|
args.batch_stride_knew = batch_stride_knew;
|
|
args.batch_stride_vnew = batch_stride_vnew;
|
|
}
|
|
else // fmha_fwd_args or fmha_fwd_splitkv_args
|
|
{
|
|
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
|
: bias_buf.GetDeviceBuffer();
|
|
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
|
args.o_ptr = o_buf.GetDeviceBuffer();
|
|
|
|
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
|
args.max_seqlen_q = max_seqlen_q;
|
|
|
|
args.scale_s = scale_s;
|
|
|
|
args.logits_soft_cap = logits_soft_cap;
|
|
|
|
args.stride_bias =
|
|
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
|
|
args.stride_o = stride_o;
|
|
args.nhead_stride_bias = nhead_stride_bias;
|
|
args.nhead_stride_lse = nhead_stride_lse;
|
|
args.nhead_stride_o = nhead_stride_o;
|
|
args.batch_stride_bias = batch_stride_bias;
|
|
args.batch_stride_lse = batch_stride_lse;
|
|
args.batch_stride_o = batch_stride_o;
|
|
|
|
args.window_size_left = mask.left;
|
|
args.window_size_right = mask.right;
|
|
args.sink_size = mask.sink;
|
|
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
            if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
            {
                if(qscale.type == quant_scale_enum::pertensor)
                {
                    args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
                    args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
                    args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
                }
                else if(qscale.type == quant_scale_enum::blockscale)
                {
                    args.q_descale_ptr =
                        reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
                    args.k_descale_ptr =
                        reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
                    args.v_descale_ptr =
                        reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());

                    args.block_scale_seqstart_q_ptr =
                        (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer()
                                                  : nullptr);
                    args.block_scale_seqstart_k_ptr =
                        (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer()
                                                  : nullptr);

                    args.nhead_stride_q_descale = nhead_stride_q_descale;
                    args.nhead_stride_k_descale = nhead_stride_k_descale;
                    args.nhead_stride_v_descale = nhead_stride_v_descale;

                    args.batch_stride_q_descale = batch_stride_q_descale;
                    args.batch_stride_k_descale = batch_stride_k_descale;
                    args.batch_stride_v_descale = batch_stride_v_descale;

                    args.block_scale_size_q = block_scale_size_q_;
                    args.block_scale_size_kv = block_scale_size_kv_;
                }
                else if(qscale.type == quant_scale_enum::mx)
                {
                    args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
                    args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
                    args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();

                    args.stride_q_descale = (i_perm ? hdim_q_scale : nhead * hdim_q_scale);
                    args.stride_k_descale = (i_perm ? hdim_q_scale : nhead_k * hdim_q_scale);
                    args.stride_v_descale =
                        (i_perm ? shape_seqlen_v_scale : nhead_k * shape_seqlen_v_scale);
                    args.nhead_stride_q_descale =
                        (i_perm ? shape_seqlen_q * hdim_q_scale : hdim_q_scale);
                    args.nhead_stride_k_descale =
                        (i_perm ? shape_seqlen_k * hdim_q_scale : hdim_q_scale);
                    args.nhead_stride_v_descale =
                        (i_perm ? hdim_v * shape_seqlen_v_scale : shape_seqlen_v_scale);
                    if(mode == mode_enum::group)
                    {
                        args.seqstart_v_scale_ptr = scale_seqstart_v_buf.GetDeviceBuffer();
                    }
                    else
                    {
                        args.batch_stride_q_descale = (nhead * shape_seqlen_q * hdim_q_scale);
                        args.batch_stride_k_descale = (nhead_k * shape_seqlen_k * hdim_q_scale);
                        args.batch_stride_v_descale = (nhead_k * hdim_v * shape_seqlen_v_scale);
                    }
                }

                args.rand_val_ptr = randval_buf.GetDeviceBuffer();

                args.stride_randval = stride_randval;
                args.nhead_stride_randval = nhead_stride_randval;
                args.batch_stride_randval = batch_stride_randval;

                args.p_drop = p_drop;
                args.s_randval = s_randval;
                if(drop_prefs)
                {
                    args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
                                                           drop_offset_buf.GetDeviceBuffer());
                }
                else
                {
                    args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
                }
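                // drop_prefs selects how the dropout RNG seed/offset reach the kernel: from
                // device buffers (as pointers) when set, otherwise as host-side immediates.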

                // Sequence length and padding parameters (mode-specific)
                if(mode == mode_enum::group)
                {
                    // Group mode: use physical (padded) cumulative starts + logical
                    // per-sequence lengths

                    // Physical cumulative starts (including padding)
                    args.seqstart_q_ptr =
                        has_group_q_padding && !seqstart_q_with_padding_host.empty()
                            ? seqstart_q_padded_buf.GetDeviceBuffer()
                            : seqstart_q_buf.GetDeviceBuffer();
                    args.seqstart_k_ptr =
                        has_group_k_padding && !seqstart_k_with_padding_host.empty()
                            ? seqstart_k_padded_buf.GetDeviceBuffer()
                            : seqstart_k_buf.GetDeviceBuffer();

                    // Logical (unpadded) per-sequence lengths, used when padding is enabled
                    args.seqlen_q_ptr =
                        (has_group_q_padding && !seqstart_q_with_padding_host.empty())
                            ? seqlen_q_buf.GetDeviceBuffer()
                            : nullptr;
                    args.seqlen_k_ptr =
                        (has_group_k_padding && !seqstart_k_with_padding_host.empty())
                            ? seqlen_k_buf.GetDeviceBuffer()
                            : nullptr;
                    // Cumulative lengths not used in group mode
                    args.cu_seqlen_q_ptr = nullptr;
                    args.cu_seqlen_k_ptr = nullptr;
                }
                else // mode == mode_enum::batch
                {
                    // Batch mode: use cumulative logical lengths for tail padding

                    // seqstart pointers not used in batch mode
                    args.seqstart_q_ptr = nullptr;
                    args.seqstart_k_ptr = nullptr;

                    // seqlen_q_ptr/seqlen_k_ptr not used in batch mode
                    args.seqlen_q_ptr = nullptr;
                    args.seqlen_k_ptr = nullptr;

                    // Cumulative logical lengths for effective length handling
                    args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty()
                                               ? cu_seqlen_q_buf.GetDeviceBuffer()
                                               : nullptr;
                    args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty()
                                               ? cu_seqlen_kv_buf.GetDeviceBuffer()
                                               : nullptr;
                }
            }
            else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
            {
                args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
                args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();

                args.block_table_ptr =
                    (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
                args.batch_stride_block_table = batch_stride_block_table;
                args.page_block_size = page_block_size;
                args.is_gappy = false; // use 'false' for flash-attention integration

                args.cache_batch_idx =
                    (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);

                args.num_splits = num_splits;

                args.stride_o_acc = stride_o_acc;
                args.nhead_stride_lse_acc = nhead_stride_lse_acc;
                args.nhead_stride_o_acc = nhead_stride_o_acc;
                args.batch_stride_lse_acc = batch_stride_lse_acc;
                args.batch_stride_o_acc = batch_stride_o_acc;
                args.split_stride_lse_acc = split_stride_lse_acc;
                args.split_stride_o_acc = split_stride_o_acc;

                args.seqstart_q_ptr =
                    (mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
                args.seqstart_k_ptr =
                    (mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
                args.seqlen_k_ptr =
                    ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
                         ? seqlen_k_buf.GetDeviceBuffer()
                         : nullptr);
            }
            else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
            {
                args.block_table_ptr =
                    (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
                args.batch_stride_block_table = batch_stride_block_table;
                args.page_block_size = page_block_size;
                args.is_gappy = false; // use 'false' for flash-attention integration

                args.cache_batch_idx =
                    (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);

                args.seqstart_q_ptr =
                    (mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
                args.seqstart_k_ptr =
                    (mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
                args.seqlen_k_ptr =
                    ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
                         ? seqlen_k_buf.GetDeviceBuffer()
                         : nullptr);
            }
        }
    };

    auto run_appendkv = [&]([[maybe_unused]] const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_APPENDKV_API
        if(need_append_kvcache)
        {
            fmha_fwd_appendkv_traits fwd_appendkv_traits;
            init_traits(fwd_appendkv_traits);

            fmha_fwd_appendkv_args fwd_appendkv_args;
            init_args(fwd_appendkv_args);

            return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
        }
#endif
        return 0.0f;
    };
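
    // run_appendkv returns 0.0f when no KV-cache append is needed; a negative time from the
    // dispatcher means no matching kernel instance was found.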
    const float appendkv_ave_time = run_appendkv(stream_config);
    if(appendkv_ave_time < 0.0f)
    {
        std::cout << ", not supported yet" << std::flush << std::endl;
        return fwd_result::no_instance;
    }

    auto run_fwd = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_PAGEDKV_API
        if(1 == num_splits && use_kvcache)
        {
            fmha_fwd_pagedkv_traits fmha_pagedkv_traits;
            init_traits(fmha_pagedkv_traits);

            fmha_fwd_pagedkv_args fmha_pagedkv_args;
            init_args(fmha_pagedkv_args);

            const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
#if CK_TILE_FMHA_FWD_SPLITKV_API
            // If there is no instance for these args, fall back to fmha_fwd_splitkv
            if(ave_time >= 0.0f)
                return ave_time;
#else
            return ave_time;
#endif
        }
#endif // CK_TILE_FMHA_FWD_PAGEDKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API
        if(1 < num_splits || use_kvcache)
        {
            fmha_fwd_splitkv_traits fmha_splitkv_traits;
            init_traits(fmha_splitkv_traits);

            fmha_fwd_splitkv_args fmha_splitkv_args;
            init_args(fmha_splitkv_args);

            return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
        }
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
        fmha_fwd_traits fmha_traits;
        init_traits(fmha_traits);

        fmha_fwd_args fmha_args;
        init_args(fmha_args);

        return fmha_fwd(fmha_traits, fmha_args, sc);
    };
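
    // Dispatch order inside run_fwd: paged-KV first (KV-cache with a single split), then
    // split-KV, and finally the plain fmha_fwd path.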

    float fwd_ave_time = -1.0f;
#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING
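    // Explicit LLC-aware head grouping only applies to the bshd layout (i_perm == 0) and to
    // the plain forward path: no KV-cache, no split-KV, and no appendkv pre-pass.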
    const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) &&
                                     !need_append_kvcache &&
                                     (mode == mode_enum::batch || mode == mode_enum::group);

    if(allow_head_grouping)
    {
        if(fmha_fwd_head_grouping::disabled_by_env())
        {
            if(fmha_fwd_head_grouping::log_enabled())
                std::cout << "[LLC Head Grouping] disabled by env" << std::endl;
        }
        else
        {
            const auto group_size_opt =
                fmha_fwd_head_grouping::get_head_group_size(nhead,
                                                            nhead_k,
                                                            batch,
                                                            max_seqlen_k,
                                                            hdim_q,
                                                            hdim_v,
                                                            sizeof(KDataType),
                                                            sizeof(VDataType));

            if(group_size_opt.has_value() && group_size_opt.value() < nhead)
            {
                if(fmha_fwd_head_grouping::log_enabled())
                {
                    const std::string arch = ck_tile::get_device_name();
                    const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch);
                    const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1);
                    const ck_tile::index_t group_sz = group_size_opt.value();
                    const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz);
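                    // e.g. nhead=24 with group_sz=8 gives ceil(24/8)=3 head groups, i.e.
                    // three grouped launches instead of one launch over all heads.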
                    std::cout << "[LLC Head Grouping] enabled" << std::endl;
                    std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch)
                              << " llc_mb=" << (llc_bytes / (1024ull * 1024ull))
                              << " nhead_q=" << nhead << " nhead_k=" << nhead_k
                              << " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz
                              << " groups=" << n_groups << std::endl;
                }
                fmha_fwd_traits fmha_traits;
                init_traits(fmha_traits);

                fmha_fwd_args fmha_args;
                init_args(fmha_args);

                fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped<QDataType,
                                                                            KDataType,
                                                                            VDataType,
                                                                            ODataType,
                                                                            BiasDataType,
                                                                            LSEDataType,
                                                                            RandValOutputDataType>(
                    stream_config,
                    fmha_traits,
                    fmha_args,
                    nhead,
                    nhead_k,
                    group_size_opt.value(),
                    qscale.type == quant_scale_enum::blockscale,
                    [&](const auto& traits, auto& args, const auto& sc) {
                        return fmha_fwd(traits, args, sc);
                    });
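                // run_fwd_head_grouped slices the Q/K/V/O (and related) pointers per head
                // group and invokes the fmha_fwd dispatch lambda once per group; global head
                // indexing keeps the dropout RNG stream mapping consistent across launches.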
            }
            else if(fmha_fwd_head_grouping::log_enabled())
            {
                std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)"
                          << std::endl;
            }
        }
    }
    else if(fmha_fwd_head_grouping::log_enabled())
    {
        std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl;
    }
#endif

    if(fwd_ave_time < 0.0f)
        fwd_ave_time = run_fwd(stream_config);
    if(fwd_ave_time < 0.0f)
    {
        std::cout << ", not supported yet" << std::flush << std::endl;
        return fwd_result::no_instance;
    }

    const float ave_time = appendkv_ave_time + fwd_ave_time;
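    // ave_time is in milliseconds: flop / 1e9 / ms == flop / 1e12 per second (TFLOPS), and
    // num_byte / 1e6 / ms == bytes / 1e9 per second (GB/s).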
    const float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    const float gb_per_sec = num_byte / 1.E6 / ave_time;
    if(stream_config.time_kernel_)
    {
        std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
                  << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
                  << gb_per_sec << " GB/s" << std::flush;
    }

    bool pass = true;
    if(do_validation == 0)
    {
        std::cout << std::flush << std::endl;
    }
    else if(do_validation == 2)
    {
        // NOTE: use the GPU to do validation
        ck_tile::naive_attention_fwd_traits naive_t;
        naive_t.q_type = data_type;
        naive_t.k_type = data_type;
        naive_t.v_type = data_type;
        naive_t.o_type = data_type;
        naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
        naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
        naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
        naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
        naive_t.variation = 0; // TODO?
        naive_t.quant_algo = 0;

        ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());

        ck_tile::naive_attention_fwd_args naive_a;
        naive_a.q_ptr = q_buf.GetDeviceBuffer();
        naive_a.k_ptr = k_buf.GetDeviceBuffer();
        naive_a.v_ptr = v_buf.GetDeviceBuffer();
        naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
        naive_a.scale_s = scale_s;
        naive_a.context_len_ptr = nullptr; // used when the kv seqlen comes from a pointer
        naive_a.page_table_ptr = nullptr;  // [batch, num_blocks]; the kv sequence lives in
                                           // separate blocks (paged attention)
        naive_a.hdim = hdim_q;
        naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
        naive_a.batch_q = batch;
        naive_a.batch_kv = batch;
        naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
        naive_a.seqlen_q = seqlen_qs[0];
        naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
        naive_a.nhead_q = nhead;
        naive_a.nhead_kv = nhead_k;
        naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
        naive_a.page_size = 0; // if paged, the seqlen-kv for each block
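        // This GPU reference path takes a single uniform length per launch (seqlen_qs[0],
        // seqlen_ks[0]) since context_len_ptr is nullptr here.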

        ck_tile::stream_config naive_s{};

        naive_attention_fwd(naive_t, naive_a, naive_s);

        auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
        o_buf.FromDevice(o_host.data()); // TODO: ugly

        auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
        pass = ck_tile::check_err(
            o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }
    else
    {
#if CK_TILE_FMHA_FWD_APPENDKV_API
        // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple
        // times when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
        if(0 < rotary_dim && stream_config.time_kernel_)
        {
            const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
            q_buf.ToDevice(q_host.data());
            run_appendkv(stream_config2);
            run_fwd(stream_config2);
        }
#endif
        o_buf.FromDevice(o_host.data());
        lse_buf.FromDevice(lse_host.data());
        randval_buf.FromDevice(randval_host.data());

        constexpr bool supports_qscale = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
                                         std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
                                         std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;

        float scale_s_host = scale_s;
        float scale_p_host = 1.0f;
        float scale_o_host = 1.0f;

        if constexpr(!is_mx)
        {
            if(qscale.type == quant_scale_enum::pertensor)
            {
                scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
                scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
                scale_o_host = v_descale_host(0) / scale_p_host;
            }
        }
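        // Per-tensor FP8 folding: the Q/K descales are absorbed into the softmax scale,
        // P is scaled up to the PDataType max before the PV gemm, and O is rescaled by
        // v_descale / scale_p afterwards (see oacc_element_func below).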

        auto p_compute_element_func = [&]() {
            if constexpr(supports_qscale)
                return ck_tile::scales{scale_p_host};
            else
                return ck_tile::identity{};
        }();

        auto oacc_element_func = [&]() {
            if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
                return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
                                              ck_tile::scales{scale_o_host});
            else if constexpr(supports_qscale)
                return ck_tile::scales{scale_o_host};
            else
                return ck_tile::identity{};
        }();

        float p_undrop = 1.0 - p_drop;
        uint8_t p_undrop_in_uint8_t =
            uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
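        // The keep probability is encoded as a uint8 threshold that reference_batched_dropout
        // compares against the per-element random bytes; surviving values are rescaled by
        // rp_undrop = 1 / p_undrop.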
        float rp_undrop = 1.0 / p_undrop;

        for(ck_tile::index_t wb = 0; wb < batch; ++wb)
        {
            ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
            ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
            if(mode == mode_enum::batch)
            {
                if(!cuq_cum.empty())
                {
                    real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb];
                }
                if(!cukv_cum.empty())
                {
                    real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb];
                }
            }

            // adjust matrix index according to the mode
            const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
            const ck_tile::index_t cache_b_idx =
                (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
            // Use physical offset if padding info is valid (not -1) and buffers are available
            const ck_tile::index_t query_offset =
                (mode == mode_enum::batch
                     ? 0
                     : ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0)
                            ? seqstart_q_host[wb]
                            : seqstart_q_with_padding_host[wb]));
            const ck_tile::index_t key_offset =
                (mode == mode_enum::batch
                     ? 0
                     : ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0)
                            ? seqstart_k_host[wb]
                            : seqstart_k_with_padding_host[wb]));

            ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
            ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
            ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
            ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});

            ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(
                {nhead, real_seqlen_q, real_seqlen_k});
            ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
            ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});

            ck_tile::index_t nr = nhead / nhead_k;
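            // GQA/MQA head mapping: query head i reads K/V head i / nr in the gathers below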

            // clang-format off
            // permute
            if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
            else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
            // clang-format on

#if CK_TILE_FMHA_FWD_APPENDKV_API
            // optionally apply RoPE to the q_host_ref
            if(0 < rotary_dim)
            {
                decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());

                auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                    rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);

                ck_tile::reference_batched_rotary_position_embedding(
                    q_host_ref,
                    rotary_cos_slice,
                    rotary_sin_slice,
                    is_rotary_interleaved,
                    q_host_ref_ro,
                    /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);

                q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
            }
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
            if(0 < page_block_size)
            {
                // clang-format off
                if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); });
                else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); });
                // clang-format on
            }
            else
#endif
            {
                // clang-format off
                if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
                else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
                // clang-format on
            }

#if CK_TILE_FMHA_FWD_APPENDKV_API
            // copy Knew to the end of K
            if(0 < seqlen_knew)
            {
                ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
                // clang-format off
                if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
                else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
                // clang-format on

                // optionally apply RoPE to the knew_host_ref
                auto* real_knew_host_ref = &knew_host_ref;
                std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
                if(0 < rotary_dim)
                {
                    knew_host_ref_ro.emplace(knew_host_ref.get_lengths());

                    auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                        rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);

                    ck_tile::reference_batched_rotary_position_embedding(knew_host_ref,
                                                                         rotary_cos_slice,
                                                                         rotary_sin_slice,
                                                                         is_rotary_interleaved,
                                                                         knew_host_ref_ro.value());

                    real_knew_host_ref = &knew_host_ref_ro.value();
                }

                (*real_knew_host_ref).ForEach([&](auto& self, auto i) {
                    k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
                });
            }
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
            if(0 < page_block_size)
            {
                if(is_v_rowmajor)
                {
                    // clang-format off
                    if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); });
                    else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); });
                    // clang-format on
                }
                else
                {
                    // clang-format off
                    if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); });
                    else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); });
                    // clang-format on
                }
            }
            else
#endif
            {
                if(is_v_rowmajor)
                {
                    // clang-format off
                    // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
                    if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
                    // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
                    else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
                    // clang-format on
                }
                else
                {
                    // clang-format off
                    if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
                    else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
                    // clang-format on
                }
            }

#if CK_TILE_FMHA_FWD_APPENDKV_API
            // copy Vnew to the end of V
            if(0 < seqlen_knew)
            {
                ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
                if(is_v_rowmajor)
                {
                    // clang-format off
                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
                    else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
                    // clang-format on
                }
                else
                {
                    // clang-format off
                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
                    else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
                    // clang-format on
                }

                vnew_host_ref.ForEach([&](auto& self, auto i) {
                    v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
                });
            }
#endif

            // reference
            if constexpr(is_mx)
            {
                ck_tile::HostTensor<QScaleDataType> q_descale_host_ref(
                    {nhead, real_seqlen_q, hdim_q_scale});
                ck_tile::HostTensor<KScaleDataType> k_descale_host_ref(
                    {nhead, real_seqlen_k, hdim_q_scale});

                // clang-format off
                if(i_perm) q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[0], i[1] + query_offset, i[2]); });
                else q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[1] + query_offset, i[0], i[2]); });
                // clang-format on

                // clang-format off
                if(i_perm) k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
                else k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
                // clang-format on

                auto q_host_ref2 = ck_tile::reference_batched_mx_descale<QDataType,
                                                                         QScaleDataType,
                                                                         SaccDataType,
                                                                         SaccDataType>(
                    q_host_ref, q_descale_host_ref, kQKScaleGranularity);
                auto k_host_ref2 = ck_tile::reference_batched_mx_descale<KDataType,
                                                                         KScaleDataType,
                                                                         SaccDataType,
                                                                         SaccDataType>(
                    k_host_ref, k_descale_host_ref, kQKScaleGranularity);

                ck_tile::reference_batched_gemm<SaccDataType,
                                                SaccDataType,
                                                SaccDataType,
                                                SMPLComputeDataType>(q_host_ref2,
                                                                     k_host_ref2,
                                                                     s_host_ref,
                                                                     ck_tile::identity{},
                                                                     ck_tile::identity{},
                                                                     ck_tile::scales(scale_s_host));
            }
            else if(qscale.type == quant_scale_enum::blockscale)
            {
                const ck_tile::index_t q_offset =
                    (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
                const ck_tile::index_t k_offset =
                    (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
                ck_tile::reference_batched_quant_gemm<QDataType,
                                                      KDataType,
                                                      SaccDataType,
                                                      SMPLComputeDataType>(
                    q_host_ref,
                    k_host_ref,
                    s_host_ref,
                    ck_tile::idx_identity{},
                    ck_tile::idx_identity{},
                    [&](auto idx, auto value) {
                        return value * scale_s *
                               q_descale_host(b_idx,
                                              std::get<0>(idx),
                                              q_offset + std::get<1>(idx) / block_scale_size_q_) *
                               k_descale_host(b_idx,
                                              std::get<0>(idx) / nr,
                                              k_offset + std::get<2>(idx) / block_scale_size_kv_);
                    });
            }
            else
            {
                ck_tile::
                    reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
                        q_host_ref,
                        k_host_ref,
                        s_host_ref,
                        ck_tile::identity{},
                        ck_tile::identity{},
                        ck_tile::scales(scale_s_host));
            }
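
            // Logits soft capping bounds the scores to (-logits_soft_cap, logits_soft_cap)
            // via s' = cap * tanh(s / cap), mirroring the device-side computation.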
            if(0.f < logits_soft_cap)
            {
                ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
                    s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
                        return ck_tile::type_convert<SaccDataType>(
                            logits_soft_cap *
                            std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
                    });
            }

            if(bias.type == bias_enum::elementwise_bias)
            {
                // elementwise bias
                ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
                // clang-format off
                if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
                else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
                // clang-format on

                // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
                // real_seqlen_k]
                ck_tile::reference_batched_elementwise<SMPLComputeDataType,
                                                       BiasDataType,
                                                       SMPLComputeDataType,
                                                       SMPLComputeDataType>(
                    s_host_ref, bias_host_ref, s_host_ref);
            }
            else if(bias.type == bias_enum::alibi)
            {
                // alibi: construct an elementwise bias to verify against
                auto alibi_host = [&]() {
                    if(mask.type != mask_enum::no_mask)
                    {
                        return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
                            0,
                            mask.left,
                            mask.right,
                            real_seqlen_q,
                            real_seqlen_k,
                            static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
                    }
                    else
                    {
                        return ck_tile::Alibi<SaccDataType, true>{
                            0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
                    }
                }();

                ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
                    {nhead, real_seqlen_q, real_seqlen_k});
                auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
                for(auto i_h = 0; i_h < nhead; i_h++)
                {
                    SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
                    alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL
                                           ? current_slope
                                           : -current_slope;
                    for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
                    {
                        for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
                        {
                            SaccDataType pixel = 0;
                            alibi_host.update(pixel, i_r, i_c);
                            alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
                        }
                    }
                }
                // [nhead, real_seqlen_q, real_seqlen_k]
                ck_tile::reference_batched_elementwise<SMPLComputeDataType,
                                                       SaccDataType,
                                                       SMPLComputeDataType,
                                                       SMPLComputeDataType>(
                    s_host_ref, alibi_bias_host_ref, s_host_ref);
            }

            if(mask.type == mask_enum::no_mask)
            {
                ck_tile::reference_batched_masking<SaccDataType>(
                    s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
            }
            else if(mask.type == mask_enum::window_generic)
            {
                ck_tile::reference_batched_masking<SaccDataType>(
                    s_host_ref,
                    ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
                        mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k));
            }
            else
            {
                // a negative left window size means causal;
                // otherwise the mask is generic (for the current batch)
                if(mask.left < 0)
                    ck_tile::reference_batched_masking<SaccDataType>(
                        s_host_ref,
                        ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
                            mask.left,
                            mask.right,
                            mask.sink,
                            real_seqlen_q,
                            real_seqlen_k,
                            mask.type == mask_enum::mask_top_left));
                else
                    ck_tile::reference_batched_masking<SaccDataType>(
                        s_host_ref,
                        ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
                            mask.left,
                            mask.right,
                            mask.sink,
                            real_seqlen_q,
                            real_seqlen_k,
                            mask.type == mask_enum::mask_top_left));
            }
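            // Keep a pre-softmax snapshot of the masked scores; it is reused below to exclude
            // fully masked (-inf) positions from the dropout randval comparison.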
            const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
            if(init_sink_value != 0)
            {
                // Create extended tensor with sink token
                ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
                    {nhead, real_seqlen_q, real_seqlen_k + 1});

                // Copy original attention scores and append sink values
                copy_attention_scores_with_sink(
                    s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);

                // Compute softmax on extended tensor
                ck_tile::HostTensor<PDataType> p_extended(
                    {nhead, real_seqlen_q, real_seqlen_k + 1});

                if(lse)
                {
                    ck_tile::reference_batched_softmax<SMPLComputeDataType,
                                                       SMPLComputeDataType,
                                                       PDataType>(
                        s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
                }
                else
                {
                    ck_tile::reference_batched_softmax<SMPLComputeDataType,
                                                       SMPLComputeDataType,
                                                       PDataType>(
                        s_with_sinks_ref, p_extended, p_compute_element_func);
                }

                // Extract only the original columns (exclude sink token column)
                p_host_ref.ForEach(
                    [&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
            }
            else
            {
                // No sink tokens - compute softmax directly
                if(lse)
                {
                    ck_tile::reference_batched_softmax<SMPLComputeDataType,
                                                       SMPLComputeDataType,
                                                       PDataType>(
                        s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
                }
                else
                {
                    ck_tile::reference_batched_softmax<SMPLComputeDataType,
                                                       SMPLComputeDataType,
                                                       PDataType>(
                        s_host_ref, p_host_ref, p_compute_element_func);
                }
            }
            if(lse)
            {
                ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
                lse_host_result.ForEach([&](auto& self, auto idx) {
                    self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
                });

                // Use smaller rtol/atol as LSE is computed and stored in fp32, so there is no
                // precision loss due to conversion
                bool cur_pass = ck_tile::check_err(lse_host_result,
                                                   lse_host_ref,
                                                   "LSE Error: Incorrect results!",
                                                   1e-4,
                                                   1e-4,
                                                   /* allow_infinity_ref = */ true);

                pass &= cur_pass;
                if(!cur_pass)
                {
                    std::cerr << "LSE mismatch found at batch: " << wb << std::endl
                              << "\tseqlen_q: " << real_seqlen_q << std::endl
                              << "\tseqlen_k: " << real_seqlen_k << std::endl
                              << "\tseqstart_q: " << seqstart_q_host << std::endl
                              << "\tseqstart_k: " << seqstart_k_host << std::endl;
                }
            }
            if(p_drop > 0)
            {
                ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
                    {nhead, real_seqlen_q, real_seqlen_k});
                ck_tile::reference_batched_dropout_randval(
                    randval_host_ref, wb, drop_seed, drop_offset);
                ck_tile::reference_batched_dropout(
                    p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);

                ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
                    {nhead, real_seqlen_q, real_seqlen_k});
                randval_host_result.ForEach([&](auto& self, const auto& idx) {
                    self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
                });
                masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) {
                    // Ignore all masked values in validation check
                    if(std::isinf(self(idx)))
                    {
                        randval_host_ref(idx) = 0;
                        randval_host_result(idx) = 0;
                    }
                });
                bool cur_pass = ck_tile::check_err(randval_host_result,
                                                   randval_host_ref,
                                                   "DROPOUT RANDVAL Error: Incorrect results!");
                pass &= cur_pass;
            }

            if constexpr(is_mx)
            {
                const ck_tile::index_t real_seqlen_v_scale =
                    seqstart_v_scale_host[wb + 1] - seqstart_v_scale_host[wb];
                const ck_tile::index_t v_scale_offset =
                    mode == mode_enum::batch ? 0 : seqstart_v_scale_host[wb];

                ck_tile::HostTensor<VScaleDataType> v_descale_host_ref(
                    {nhead, hdim_v, real_seqlen_v_scale});

                // clang-format off
                if(i_perm) v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[0] / nr, i[1], i[2] + v_scale_offset); });
                else v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[1], i[0] / nr, i[2] + v_scale_offset); });
                // clang-format on

                auto v_host_ref2 = ck_tile::reference_batched_mx_descale<VDataType,
                                                                         VScaleDataType,
                                                                         OaccDataType,
                                                                         OaccDataType>(
                    v_host_ref, v_descale_host_ref, kVScaleGranularity);

                // P is not quantized and then dequantized here (PDataType = float).
                // On the host, softmax is computed over the whole row of S, while on the device
                // FA computes softmax and quantizes it in blocks of N0 values. Quantizing on the
                // host would make the reference results **less** precise than the device results
                // for large seqlen_k!

                ck_tile::reference_batched_gemm<PDataType, OaccDataType, OaccDataType, ODataType>(
                    p_host_ref,
                    v_host_ref2,
                    o_host_ref,
                    ck_tile::identity{},
                    ck_tile::identity{},
                    oacc_element_func);
            }
            else if(qscale.type == quant_scale_enum::blockscale)
            {
                const ck_tile::index_t v_offset =
                    (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
                ck_tile::
                    reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
                        p_host_ref,
                        v_host_ref,
                        o_host_ref,
                        ck_tile::idx_identity{},
                        [&](auto idx, auto value) {
                            return ck_tile::type_convert<float>(value) *
                                   v_descale_host(b_idx,
                                                  std::get<0>(idx) / nr,
                                                  v_offset +
                                                      std::get<2>(idx) / block_scale_size_kv_);
                        },
                        ck_tile::idx_identity{});
            }
            else
            {
                ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
                    p_host_ref,
                    v_host_ref,
                    o_host_ref,
                    ck_tile::identity{},
                    ck_tile::identity{},
                    oacc_element_func);
            }

            ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
            // clang-format off
            // permute
            if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
            else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
            // clang-format on
            auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
            bool cur_pass = ck_tile::check_err(o_host_result,
                                               o_host_ref,
                                               std::string("OUT Error: Incorrect results!"),
                                               rtol,
                                               atol);
            pass &= cur_pass;
            if(!cur_pass)
            {
                std::cerr << "OUT mismatch found at batch: " << wb << std::endl
                          << "\tseqlen_q: " << real_seqlen_q << std::endl
                          << "\tseqlen_k: " << real_seqlen_k << std::endl
                          << "\tseqstart_q (logical): " << seqstart_q_host << std::endl
                          << "\tseqstart_q (physical): " << seqstart_q_with_padding_host
                          << std::endl
                          << "\tseqstart_k (logical): " << seqstart_k_host << std::endl
                          << "\tseqstart_k (physical): " << seqstart_k_with_padding_host
                          << std::endl
                          << "\tquery_offset used: " << query_offset << std::endl
                          << "\tkey_offset used: " << key_offset << std::endl;
            }

            if(!pass)
            {
                break;
            }
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    if(json)
    {
        const std::string qscale_name =
            (qscale.type == quant_scale_enum::no_scale ? "no_scale"
             : qscale.type == quant_scale_enum::pertensor ? "pertensor"
             : qscale.type == quant_scale_enum::blockscale ? "blockscale"
             : qscale.type == quant_scale_enum::kv_blockscale ? "kv_blockscale"
             : qscale.type == quant_scale_enum::mx ? "mx"
                                                   : "unknown");
        dump_fmha_fwd_json_results(*json,
                                   data_type,
                                   mode == mode_enum::batch ? "batch" : "group",
                                   io_layout(i_perm, o_perm),
                                   batch,
                                   nhead,
                                   nhead_k,
                                   seqlen_qs[0],
                                   seqlen_ks[0],
                                   seqlen_kpads[0],
                                   hdim_q,
                                   hdim_v,
                                   scale_s,
                                   p_drop,
                                   lse,
                                   qscale_name,
                                   bias.type == bias_enum::elementwise_bias
                                       ? "elementwise_bias"
                                       : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
                                   is_v_rowmajor ? "r" : "c",
                                   pass,
                                   ave_time,
                                   tflops,
                                   gb_per_sec);
    }

    return pass ? fwd_result::success : fwd_result::failure;
}