mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
[CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation
* use different seed for each input (Q, K, V...);
* use deterministic generation of:
* seqstart_q/k (for group mode);
* block_table (for paged-kvcache);
* cache_batch_idx (for kvcache);
* Extract arg_parser-related code from run functions to use them as tests
* Split examples into main programs and fmha runners, build instances separately
* Add dummy tests that use instances and runners
* Fix a missed corner case of f32->f8 conversion
When the value is < min f8 denormal but > min f8 denormal / 2, it must be
rounded to min f8 denormal (i.e. 0b1), not to 0.
* Fix incorrect fp8 scales for P and O in validation code
DataTypeConfig was incorrectly compared with fp8_t.
* Add host generation of dropout random values and use it for validation
Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.
* Implement tests from smoke_test_bwd.sh
* Return result as enum to distinguish failure and missing instance
* Add tests for bwd features: bias, alibi, dropout
* Implement tests from smoke_test_fwd.sh
* Pass seqlen_q/k as vectors to fwd and bwd runners
* Add tests for fwd features: bias, alibi, dropout
* Add tests for pagedkv and splitkv
* Fix conditions when to use splitkv and pagedkv kernels
splitkv was executed only when use_kvcache, which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1.
In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance.
* Add tests for appendkv
* Use is_v_rowmajor = true because there are no instances with column layout anymore
* Split public and private compile options for instances
Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.
* Improve parsing validation in bias and mask
* Pass bias as string for consistency with mask
* Catch parsing and other exceptions
* Add bwd test for deterministic flag
* Initialize fp8 tensors (-init=ufq) similarly to uf
* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null
seqlen_k cannot be used to determine padding when seqlen_k_ptr is
provided. The actual seqlen_k is taken from seqlen_k_ptr[b].
Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr
may contain arbitrary values.
In the example or tests this produces incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
* Fix use_pagedkv value when kvcache = true but page_block_size = 0
In this case block_table_ptr is nullptr which is accessed in the kernel.
* Clean up bwd tests
* Unify fwd tests for f16/bf16 and fp8
* Use better explicit instantiation declaration for fmha_bwd<2>
* Use the same seed for all tests, allow to override it with env variable
* Undo clang-format of one irrelevant file
For some reason my local clang-format-18 and the one in CI work differently.
* Do not build instances and tests on unsupported archs
* Build instance libraries as OBJECT library
* CI: Enable sccache for HIP
There are source files with LANGUAGE HIP, they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache
* Add tests to REGRESSION_TESTS
* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0
The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are
smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf.
* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options
The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.
* Fix formatting and names
[ROCm/composable_kernel commit: ec006bb8e0]
This commit is contained in:
@@ -399,9 +399,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
}
|
||||
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
|
||||
// The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents
|
||||
// an undefined behavior of bit shifts >= type width).
|
||||
if(exponent_diff > DstT_mant)
|
||||
if(exponent_diff > DstT_mant + 1)
|
||||
{
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_masking.hpp"
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cstring>
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename RandValOutputDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
|
||||
index_t batch,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
|
||||
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
|
||||
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
|
||||
|
||||
// BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
|
||||
// order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
|
||||
// different warp gemms (16x16 or 32x32).
|
||||
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
|
||||
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
|
||||
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
|
||||
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
|
||||
// C j: (lane % 32)
|
||||
// With SFactor = 2 it becomes:
|
||||
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
|
||||
// C j: (lane % 32)
|
||||
|
||||
constexpr index_t max_warp_size = 64;
|
||||
constexpr index_t warp_gemm_mn = 32;
|
||||
|
||||
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
|
||||
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
|
||||
|
||||
auto f = [&](index_t i_h, index_t row, index_t col) {
|
||||
uint2 rowcol = make_uint2(row, col);
|
||||
for(index_t lane = 0; lane < max_warp_size; lane++)
|
||||
{
|
||||
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
|
||||
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
for(auto r = 0; r < 16; r++)
|
||||
{
|
||||
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
|
||||
index_t j = (lane % 32);
|
||||
index_t m = row * warp_gemm_mn + i;
|
||||
index_t n = col * warp_gemm_mn + j;
|
||||
|
||||
if(m < real_seqlen_q && n < real_seqlen_k)
|
||||
{
|
||||
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -611,7 +611,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& bais,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
@@ -636,7 +636,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("bias", bais);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
|
||||
Reference in New Issue
Block a user