mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
* Improve random number generation
* use different seed for each input (Q, K, V...);
* use deterministic generation of:
* seqstart_q/k (for group mode);
* block_table (for paged-kvcache);
* cache_batch_idx (for kvcache);
* Extract arg_parser-related code from run functions to use them as tests
* Split examples into main programs and fmha runners, build instances separately
* Add dummy tests that use instances and runners
* Fix a missed corner case of f32->f8 conversion
When the value is < min f8 denormal but > min f8 denormal / 2, it must be
rounded to min f8 denormal (i.e. 0b1), not to 0.
* Fix incorrect fp8 scales for P and O in validation code
DataTypeConfig was incorrectly compared with fp8_t.
* Add host generation of dropout random values and use it for validation
Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.
* Implement tests from smoke_test_bwd.sh
* Return result as enum to distinguish failure and missing instance
* Add tests for bwd features: bias, alibi, dropout
* Implement tests from smoke_test_fwd.sh
* Pass seqlen_q/k as vectors to fwd and bwd runners
* Add tests for fwd features: bias, alibi, dropout
* Add tests for pagedkv and splitkv
* Fix conditions when to use splitkv and pagedkv kernels
splitkv was executed only when use_kvcache was true, where use_kvcache == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1.
In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance.
* Add tests for appendkv
* Use is_v_rowmajor = true because there are no instances with column layout anymore
* Split public and private compile options for instances
Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.
* Improve parsing validation in bias and mask
* Pass bias as string for consistency with mask
* Catch parsing and other exceptions
* Add bwd test for deterministic flag
* Initialize fp8 tensors (-init=ufq) similarly to uf
* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null
seqlen_k cannot be used to determine padding when seqlen_k_ptr is
provided. The actual seqlen_k is taken from seqlen_k_ptr[b].
Even when seqlen_k % bn0 == 0, the padded seqlen_k path must be used, because seqlen_k_ptr
may contain arbitrary values.
In the example or tests this produces incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
* Fix use_pagedkv value when kvcache = true but page_block_size = 0
In this case block_table_ptr is nullptr which is accessed in the kernel.
* Clean up bwd tests
* Unify fwd tests for f16/bf16 and fp8
* Use better explicit instantiation declaration for fmha_bwd<2>
* Use the same seed for all tests, allow to override it with env variable
* Undo clang-format of one irrelevant file
For some reason my local clang-format-18 and the one in CI work differently.
* Do not build instances and tests on unsupported archs
* Build instance libraries as OBJECT library
* CI: Enable sccache for HIP
There are source files with LANGUAGE HIP, they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache
* Add tests to REGRESSION_TESTS
* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0
The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are
smaller tiles (for tr_load or fp32). This can result in a dq_acc_buf that is too small.
* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options
The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.
* Fix formatting and names
[ROCm/composable_kernel commit: ec006bb8e0]
115 lines
3.0 KiB
C++
115 lines
3.0 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <cstddef>
#include <ostream>
#include <stdexcept>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
// keep sync with BlockAttentionBiasEnum
|
|
/// Kind of attention bias applied by the FMHA runners.
/// The numeric values are significant: they must match ck_tile's BlockAttentionBiasEnum.
enum class bias_enum : int
{
    no_bias          = 0, // no bias term is applied
    elementwise_bias = 1, // explicit bias tensor; layout selected by bias_info::rank_info
    alibi            = 2, // ALiBi bias; layout selected by bias_info::rank_info
};
|
|
|
|
struct bias_info
|
|
{
|
|
bias_enum type;
|
|
/*
|
|
* simple dispatch logic
|
|
*
|
|
* if type == elementwise_bias:
|
|
* if rank_info == 0:
|
|
* bias is 1*1*s*s
|
|
* elif rank_info == 1:
|
|
* bias is 1*h*s*s
|
|
* elif rank_info == 2:
|
|
* bias is b*h*s*s
|
|
*
|
|
* elif type == alibi:
|
|
* if rank_info == 0:
|
|
* alibi in 1*h
|
|
* elif rank_info == 1:
|
|
* alibi in b*h
|
|
*/
|
|
int rank_info;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == bias_enum::no_bias)
|
|
os << "n";
|
|
else if(type == bias_enum::elementwise_bias)
|
|
{
|
|
os << "e";
|
|
if(rank_info != 0)
|
|
{
|
|
os << "[" << rank_info << "]";
|
|
}
|
|
}
|
|
else if(type == bias_enum::alibi)
|
|
{
|
|
os << "alibi";
|
|
if(rank_info != 0)
|
|
{
|
|
os << "[" << rank_info << "]";
|
|
}
|
|
}
|
|
}
|
|
|
|
static bias_info decode(std::string str)
|
|
{
|
|
bias_info info{bias_enum::no_bias, 0};
|
|
auto found_0 = str.find(':');
|
|
if(found_0 != std::string::npos)
|
|
{
|
|
std::string t = str.substr(0, found_0);
|
|
std::string v = str.substr(found_0 + 1);
|
|
if(t == "e" || t == "elementwise")
|
|
{
|
|
info.type = bias_enum::elementwise_bias;
|
|
info.rank_info = std::stoi(v);
|
|
if(info.rank_info < 0 || info.rank_info > 2)
|
|
throw std::invalid_argument("invalid bias rank: " + str);
|
|
}
|
|
else if(t == "a" || t == "alibi")
|
|
{
|
|
info.type = bias_enum::alibi;
|
|
info.rank_info = std::stoi(v);
|
|
if(info.rank_info < 0 || info.rank_info > 1)
|
|
throw std::invalid_argument("invalid bias rank: " + str);
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid bias value: " + str);
|
|
}
|
|
}
|
|
else if(str == "0" || str == "n")
|
|
{
|
|
info.type = bias_enum::no_bias;
|
|
}
|
|
else if(str == "1" || str == "e" || str == "elementwise")
|
|
{
|
|
info.type = bias_enum::elementwise_bias;
|
|
}
|
|
else if(str == "2" || str == "a" || str == "alibi")
|
|
{
|
|
info.type = bias_enum::alibi;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid bias value: " + str);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
|
{
|
|
bi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|