mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation * use different seed for each input (Q, K, V...); * use deterministic generation of: * seqstart_q/k (for group mode); * block_table (for paged-kvcahe); * cache_batch_idx (for kvcache); * Extract arg_parser-related code from run functions to use them as tests * Split examples into main programs and fmha runners, build instances separately * Add dummy tests that use instances and runners * Fix a missed corner case of f32->f8 conversion When value if < min f8 denormal but > min f8 denormal / 2, it must be rounded to min f8 denormal (i.e. 0b1), not to 0. * Fix incorrect fp8 scales for P and O in validation code DataTypeConfig was incorrectly compared with fp8_t. * Add host generation of dropout random values and use it for validation Previously host validation (reference_batched_dropout) used random numbers generated by BlockDropout of the kernel, meaning that incorrect generation on device (bad distribution, repeated numbers, too many zeros, etc.) would not trigger any validation errors. * Implement tests from smoke_test_bwd.sh * Return result as enum to distinguish failure and missing instance * Add tests for bwd features: bias, alibi, dropout * Implement tests from smoke_test_fwd.sh * Pass seqlen_q/k as vectors to fwd and bwd runners * Add tests for fwd features: bias, alibi, dropout * Add tests for pagedkv and splitkv * Fix conditions when to use splitkv and pagedkv kernels splitkv was executed only when use_kvcache which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size). In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1. In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance. * Add tests for appendkv * Use is_v_rowmajor = true because there are no instances with column layout anymore * Split public and private compile options for instances Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API. * Improve parsing validation in bias and mask * Pass bias as string for consistency with mask * Catch parsing and other exceptions * Add bwd test for deterministic flag * Initialize fp8 tensors (-init=ufq) similarly to uf * Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided. The actual seqlen_k is taken from seqlen_k_ptr[b]. Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr may contain arbitrary values. In the example or tests this produces incorrect results with appendkv (for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8). * Fix use_pagedkv value when kvcache = true but page_block_size = 0 In this case block_table_ptr is nullptr which is accessed in the kernel. * Clean up bwd tests * Unify fwd tests for f16/bf16 and fp8 * Use better explicit instantiation declaration for fmha_bwd<2> * Use the same seed for all tests, allow to override it with env variable * Undo clang-format of one irrelevant file For some reason my local clang-format-18 and the one in CI work differently. * Do not build instances and tests on unsupported archs * Build instance libraries as OBJECT library * CI: Enable sccache for HIP There are source files with LANGUAGE HIP, they need -DCMAKE_HIP_COMPILER_LAUNCHER=sccache * Add tests to REGRESSION_TESTS * Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0 The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf. * Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options The instances don't actually depend on them, only examples and tests do. Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS without recompiling instances that are already in ccache. * Fix formatting and names
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename RandValOutputDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
|
||||
index_t batch,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
|
||||
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
|
||||
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
|
||||
|
||||
// BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
|
||||
// order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
|
||||
// different warp gemms (16x16 or 32x32).
|
||||
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
|
||||
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
|
||||
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
|
||||
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
|
||||
// C j: (lane % 32)
|
||||
// With SFactor = 2 it becomes:
|
||||
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
|
||||
// C j: (lane % 32)
|
||||
|
||||
constexpr index_t max_warp_size = 64;
|
||||
constexpr index_t warp_gemm_mn = 32;
|
||||
|
||||
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
|
||||
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
|
||||
|
||||
auto f = [&](index_t i_h, index_t row, index_t col) {
|
||||
uint2 rowcol = make_uint2(row, col);
|
||||
for(index_t lane = 0; lane < max_warp_size; lane++)
|
||||
{
|
||||
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
|
||||
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
for(auto r = 0; r < 16; r++)
|
||||
{
|
||||
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
|
||||
index_t j = (lane % 32);
|
||||
index_t m = row * warp_gemm_mn + i;
|
||||
index_t n = col * warp_gemm_mn + j;
|
||||
|
||||
if(m < real_seqlen_q && n < real_seqlen_k)
|
||||
{
|
||||
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user