mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation * use different seed for each input (Q, K, V...); * use deterministic generation of: * seqstart_q/k (for group mode); * block_table (for paged-kvcahe); * cache_batch_idx (for kvcache); * Extract arg_parser-related code from run functions to use them as tests * Split examples into main programs and fmha runners, build instances separately * Add dummy tests that use instances and runners * Fix a missed corner case of f32->f8 conversion When value if < min f8 denormal but > min f8 denormal / 2, it must be rounded to min f8 denormal (i.e. 0b1), not to 0. * Fix incorrect fp8 scales for P and O in validation code DataTypeConfig was incorrectly compared with fp8_t. * Add host generation of dropout random values and use it for validation Previously host validation (reference_batched_dropout) used random numbers generated by BlockDropout of the kernel, meaning that incorrect generation on device (bad distribution, repeated numbers, too many zeros, etc.) would not trigger any validation errors. * Implement tests from smoke_test_bwd.sh * Return result as enum to distinguish failure and missing instance * Add tests for bwd features: bias, alibi, dropout * Implement tests from smoke_test_fwd.sh * Pass seqlen_q/k as vectors to fwd and bwd runners * Add tests for fwd features: bias, alibi, dropout * Add tests for pagedkv and splitkv * Fix conditions when to use splitkv and pagedkv kernels splitkv was executed only when use_kvcache which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size). In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1. In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance. * Add tests for appendkv * Use is_v_rowmajor = true because there are no instances with column layout anymore * Split public and private compile options for instances Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API. * Improve parsing validation in bias and mask * Pass bias as string for consistency with mask * Catch parsing and other exceptions * Add bwd test for deterministic flag * Initialize fp8 tensors (-init=ufq) similarly to uf * Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided. The actual seqlen_k is taken from seqlen_k_ptr[b]. Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr may contain arbitrary values. In the example or tests this produces incorrect results with appendkv (for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8). * Fix use_pagedkv value when kvcache = true but page_block_size = 0 In this case block_table_ptr is nullptr which is accessed in the kernel. * Clean up bwd tests * Unify fwd tests for f16/bf16 and fp8 * Use better explicit instantiation declaration for fmha_bwd<2> * Use the same seed for all tests, allow to override it with env variable * Undo clang-format of one irrelevant file For some reason my local clang-format-18 and the one in CI work differently. * Do not build instances and tests on unsupported archs * Build instance libraries as OBJECT library * CI: Enable sccache for HIP There are source files with LANGUAGE HIP, they need -DCMAKE_HIP_COMPILER_LAUNCHER=sccache * Add tests to REGRESSION_TESTS * Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0 The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf. * Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options The instances don't actually depend on them, only examples and tests do. Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS without recompiling instances that are already in ccache. * Fix formatting and names
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
@@ -28,6 +27,23 @@ std::ostream& operator<<(std::ostream& stream, mode_enum mode)
|
||||
return stream << (mode == mode_enum::batch ? "batch" : "group");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
{
|
||||
std::vector<int32_t> seqstarts = {0};
|
||||
@@ -39,12 +55,13 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min = -1, // if not negative, clamp min
|
||||
int32_t seqlen_max = -1, // if not negative, clamp max
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
int32_t seqlen_min, // if not negative, clamp min
|
||||
int32_t seqlen_max, // if not negative, clamp max
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
@@ -58,7 +75,6 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
|
||||
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
|
||||
|
||||
@@ -89,43 +105,31 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
std::vector<int32_t> generate_seqstarts(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min = -1,
|
||||
int32_t seqlen_max = -1,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
|
||||
}
|
||||
|
||||
// return random integer generated uniformly in range [low, high]
|
||||
template <typename Int = int>
|
||||
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
|
||||
-> std::enable_if_t<std::is_integral_v<Int>, Int>
|
||||
template <typename Int = int, typename RandomEngine>
|
||||
auto randint(Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
|
||||
{
|
||||
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
return dist(engine);
|
||||
return dist(random_engine);
|
||||
}
|
||||
|
||||
// return random integers generated uniformly in range [low, high]
|
||||
template <typename Int, typename ForwardIterator>
|
||||
template <typename Int, typename ForwardIterator, typename RandomEngine>
|
||||
auto randints(ForwardIterator first,
|
||||
ForwardIterator last,
|
||||
Int low,
|
||||
Int high,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
-> std::enable_if_t<std::is_integral_v<Int>>
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
|
||||
{
|
||||
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
|
||||
std::generate(first, last, [&] { return dist(engine); });
|
||||
std::generate(first, last, [&] { return dist(random_engine); });
|
||||
}
|
||||
|
||||
/*
|
||||
* decode the seqlen string from cmdline
|
||||
* generate missing values in *_val randomly when the number of values is smaller than batch
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
@@ -136,23 +140,23 @@ auto randints(ForwardIterator first,
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
decode_seqlen(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
std::string q_val,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min = 0,
|
||||
bool need_append_kvcache = false,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = _S2I_(q_val);
|
||||
ck_tile::index_t k = _S2I_(k_val);
|
||||
ck_tile::index_t q = q_val[0];
|
||||
ck_tile::index_t k = k_val[0];
|
||||
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = [&] {
|
||||
@@ -166,7 +170,7 @@ decode_seqlen(mode_enum mode,
|
||||
seqlen_ks.end(),
|
||||
seqlen_k_min,
|
||||
seqlen_k_max,
|
||||
seed);
|
||||
random_engine);
|
||||
return seqlen_ks;
|
||||
}
|
||||
|
||||
@@ -187,25 +191,19 @@ decode_seqlen(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t idx = 0;
|
||||
std::string::size_type pos_q = 0;
|
||||
std::string::size_type pos_k = 0;
|
||||
std::string::size_type pos_kp = 0;
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
while(true)
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
auto found_q = q_val.find(',', pos_q);
|
||||
auto found_k = k_val.find(',', pos_k);
|
||||
auto found_kp = k_pad_val.find(',', pos_kp);
|
||||
|
||||
ck_tile::index_t q = _S2I_(
|
||||
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
|
||||
ck_tile::index_t k = _S2I_(
|
||||
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
|
||||
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
|
||||
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
|
||||
ck_tile::index_t q = q_val[idx];
|
||||
ck_tile::index_t k =
|
||||
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
|
||||
ck_tile::index_t kp =
|
||||
k_pad_val.empty()
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
@@ -219,21 +217,13 @@ decode_seqlen(mode_enum mode,
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
idx++;
|
||||
if(found_q == std::string::npos || idx >= batch)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos_q = found_q + 1;
|
||||
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
|
||||
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
|
||||
auto rem_k =
|
||||
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
|
||||
auto rem_q =
|
||||
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
|
||||
auto rem_k = generate_seqlens(
|
||||
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
@@ -241,26 +231,14 @@ decode_seqlen(mode_enum mode,
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
int env_get_int(const char* var_name, int default_int)
|
||||
{
|
||||
char* v = getenv(var_name);
|
||||
int r = default_int;
|
||||
if(v)
|
||||
r = std::atoi(v);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename RandomAccessIterator, typename Int>
|
||||
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
|
||||
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
|
||||
RandomAccessIterator last,
|
||||
Int value,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
std::iota(first, last, value);
|
||||
|
||||
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::shuffle(first, last, engine);
|
||||
std::shuffle(first, last, random_engine);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user