[CK_TILE] Add gtests for FMHA (#2744)

* Improve random number generation

* use different seed for each input (Q, K, V...);
* use deterministic generation of:
  * seqstart_q/k (for group mode);
  * block_table (for paged-kvcache);
  * cache_batch_idx (for kvcache);

* Extract arg_parser-related code from run functions to use them as tests

* Split examples into main programs and fmha runners, build instances separately

* Add dummy tests that use instances and runners

* Fix a missed corner case of f32->f8 conversion

When a value is less than the min f8 denormal but greater than min f8 denormal / 2,
it must be rounded to the min f8 denormal (i.e. 0b1), not to 0.
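The corner case can be sketched as follows (a standalone illustration, not the actual CK conversion code; `min_denorm` stands for the smallest positive f8 subnormal magnitude, e.g. 2^-9 for e4m3):

```cpp
#include <cstdint>

// Illustrative sketch of round-to-nearest behavior near zero for f32 -> f8.
// min_denorm is the smallest positive f8 subnormal magnitude (assumption:
// positive inputs only; the full conversion is elided).
inline uint8_t round_tiny_f32_to_f8(float value, float min_denorm)
{
    if(value >= min_denorm)
        return 0b1; // at least the smallest denormal survives
    // Anything strictly greater than min_denorm / 2 is closer to min_denorm
    // (bit pattern 0b1) than to 0, so it must not be flushed to zero.
    // This was the previously missed case.
    if(value > min_denorm / 2)
        return 0b1;
    return 0; // values <= min_denorm / 2 round to 0
}
```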

* Fix incorrect fp8 scales for P and O in validation code

DataTypeConfig was incorrectly compared with fp8_t.

* Add host generation of dropout random values and use it for validation

Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.

* Implement tests from smoke_test_bwd.sh

* Return result as enum to distinguish failure and missing instance

* Add tests for bwd features: bias, alibi, dropout

* Implement tests from smoke_test_fwd.sh

* Pass seqlen_q/k as vectors to fwd and bwd runners

* Add tests for fwd features: bias, alibi, dropout

* Add tests for pagedkv and splitkv

* Fix conditions when to use splitkv and pagedkv kernels

splitkv was executed only when use_kvcache was set, which equals
(need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests, the regular fwd kernel was executed if use_cache_batch_idx
was not requested, even when num_splitkv > 1.
In the AppendKV tests, the pagedkv kernel was executed but often failed to find an instance.
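The corrected dispatch can be sketched roughly like this (struct and member names are illustrative, derived from the description above rather than taken from the actual runner code):

```cpp
// Illustrative sketch of the fixed kernel selection; not the actual runner.
struct FwdDispatch
{
    bool need_append_kvcache;
    bool use_cache_batch_idx;
    int page_block_size;
    int num_splitkv;

    bool use_kvcache() const
    {
        return need_append_kvcache || use_cache_batch_idx || 0 < page_block_size;
    }
    // Before the fix, splitkv ran only when use_kvcache() held; now it also
    // runs whenever more than one split is requested.
    bool use_splitkv() const { return use_kvcache() || num_splitkv > 1; }
    // pagedkv needs a block table, which only exists when page_block_size > 0.
    bool use_pagedkv() const { return 0 < page_block_size; }
};
```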

* Add tests for appendkv

* Use is_v_rowmajor = true because there are no instances with column layout anymore

* Split public and private compile options for instances

Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.

* Improve parsing validation in bias and mask

* Pass bias as string for consistency with mask

* Catch parsing and other exceptions

* Add bwd test for deterministic flag

* Initialize fp8 tensors (-init=ufq) similarly to uf

* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null

seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided;
the actual seqlen_k is taken from seqlen_k_ptr[b].
Even a seqlen_k that is divisible by bn0 (% bn0 == 0) must use the padded
path, because seqlen_k_ptr may contain arbitrary values.
In the example or tests this produced incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
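A minimal sketch of the fixed padding decision (function and parameter names are hypothetical, not the actual invoker code):

```cpp
#include <cstdint>

// Illustrative sketch: decide whether the padded-seqlen_k kernel variant is
// required. bn0 is the kernel's N0 tile size along seqlen_k.
inline bool need_padded_seqlen_k(const int32_t* seqlen_k_ptr,
                                 int32_t seqlen_k,
                                 int32_t bn0)
{
    // With seqlen_k_ptr, the real per-batch lengths are seqlen_k_ptr[b] and
    // can be arbitrary, so the host-side seqlen_k proves nothing: always pad.
    if(seqlen_k_ptr != nullptr)
        return true;
    // Without it, divisibility of the uniform seqlen_k by the tile decides.
    return seqlen_k % bn0 != 0;
}
```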

* Fix use_pagedkv value when kvcache = true but page_block_size = 0

In this case block_table_ptr is nullptr, yet the kernel would dereference it.

* Clean up bwd tests

* Unify fwd tests for f16/bf16 and fp8

* Use better explicit instantiation declaration for fmha_bwd<2>

* Use the same seed for all tests, allow overriding it with an env variable

* Undo clang-format of one irrelevant file

For some reason my local clang-format-18 and the one in CI work differently.

* Do not build instances and tests on unsupported archs

* Build instance libraries as OBJECT library

* CI: Enable sccache for HIP

There are source files compiled with LANGUAGE HIP; they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache.

* Add tests to REGRESSION_TESTS

* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0

The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64, but some instances use
smaller tiles (for tr_load or fp32), which can make dq_acc_buf too small.
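The effect of the wrong tile size can be illustrated with a hypothetical sizing helper (an assumption about how the accumulator scales, not the actual runner code):

```cpp
#include <cstddef>

// Illustrative sketch: deterministic bwd keeps one partial-dq slice per N0
// tile along seqlen_k, so the accumulator grows as ceil(seqlen_k / kN0).
// Sizing it with an assumed kN0 = 128 while the instance actually tiles with
// a smaller kN0 under-allocates the buffer, causing OOB accesses.
inline std::size_t dq_acc_elems(std::size_t seqlen_q,
                                std::size_t hdim_q,
                                std::size_t seqlen_k,
                                std::size_t kN0)
{
    const std::size_t n_tiles = (seqlen_k + kN0 - 1) / kN0;
    return n_tiles * seqlen_q * hdim_q;
}
```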

* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options

The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows changing FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.

* Fix formatting and names
Anton Gorenko
2025-09-10 09:06:14 +06:00
committed by GitHub
parent c254f3d7b4
commit ec006bb8e0
27 changed files with 2429 additions and 865 deletions

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename RandValOutputDataType>
CK_TILE_HOST void
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
                                  index_t batch,
                                  uint64_t drop_seed,
                                  uint64_t drop_offset)
{
    const index_t nhead         = randval_b_m_n.mDesc.get_lengths()[0];
    const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
    const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];

    static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);

    // BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
    // order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
    // different warp gemms (16x16 or 32x32).
    // To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
    // WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
    // Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
    //   C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
    //   C j: (lane % 32)
    // With SFactor = 2 it becomes:
    //   C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
    //   C j: (lane % 32)
    constexpr index_t max_warp_size = 64;
    constexpr index_t warp_gemm_mn  = 32;

    const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
    const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);

    auto f = [&](index_t i_h, index_t row, index_t col) {
        uint2 rowcol = make_uint2(row, col);
        for(index_t lane = 0; lane < max_warp_size; lane++)
        {
            philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
            uint8_t random_uint8_t[16];
            ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
            for(auto r = 0; r < 16; r++)
            {
                index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
                index_t j = (lane % 32);
                index_t m = row * warp_gemm_mn + i;
                index_t n = col * warp_gemm_mn + j;
                if(m < real_seqlen_q && n < real_seqlen_k)
                {
                    randval_b_m_n(i_h, m, n) = random_uint8_t[r];
                }
            }
        }
    };

    make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
}

} // namespace ck_tile