mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* Improve random number generation * use different seed for each input (Q, K, V...); * use deterministic generation of: * seqstart_q/k (for group mode); * block_table (for paged-kvcache); * cache_batch_idx (for kvcache); * Extract arg_parser-related code from run functions to use them as tests * Split examples into main programs and fmha runners, build instances separately * Add dummy tests that use instances and runners * Fix a missed corner case of f32->f8 conversion When the value is < min f8 denormal but > min f8 denormal / 2, it must be rounded to min f8 denormal (i.e. 0b1), not to 0. * Fix incorrect fp8 scales for P and O in validation code DataTypeConfig was incorrectly compared with fp8_t. * Add host generation of dropout random values and use it for validation Previously host validation (reference_batched_dropout) used random numbers generated by BlockDropout of the kernel, meaning that incorrect generation on device (bad distribution, repeated numbers, too many zeros, etc.) would not trigger any validation errors. * Implement tests from smoke_test_bwd.sh * Return result as enum to distinguish failure and missing instance * Add tests for bwd features: bias, alibi, dropout * Implement tests from smoke_test_fwd.sh * Pass seqlen_q/k as vectors to fwd and bwd runners * Add tests for fwd features: bias, alibi, dropout * Add tests for pagedkv and splitkv * Fix conditions when to use splitkv and pagedkv kernels splitkv was executed only when use_kvcache, which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size). In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1. In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance. 
* Add tests for appendkv * Use is_v_rowmajor = true because there are no instances with column layout anymore * Split public and private compile options for instances Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API. * Improve parsing validation in bias and mask * Pass bias as string for consistency with mask * Catch parsing and other exceptions * Add bwd test for deterministic flag * Initialize fp8 tensors (-init=ufq) similarly to uf * Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided. The actual seqlen_k is taken from seqlen_k_ptr[b]. Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr may contain arbitrary values. In the example or tests this produces incorrect results with appendkv (for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8). * Fix use_pagedkv value when kvcache = true but page_block_size = 0 In this case block_table_ptr is nullptr which is accessed in the kernel. * Clean up bwd tests * Unify fwd tests for f16/bf16 and fp8 * Use better explicit instantiation declaration for fmha_bwd<2> * Use the same seed for all tests, allow to override it with env variable * Undo clang-format of one irrelevant file For some reason my local clang-format-18 and the one in CI work differently. * Do not build instances and tests on unsupported archs * Build instance libraries as OBJECT library * CI: Enable sccache for HIP There are source files with LANGUAGE HIP, they need -DCMAKE_HIP_COMPILER_LAUNCHER=sccache * Add tests to REGRESSION_TESTS * Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0 The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf. * Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options The instances don't actually depend on them, only examples and tests do. 
Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS without recompiling instances that are already in ccache. * Fix formatting and names
168 lines
5.5 KiB
C++
168 lines
5.5 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
#include <ostream>
#include <stdexcept>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
// Mask kinds understood by mask_info. Keep this in sync with ck_tile::GenericAttentionMaskEnum;
// the numeric values are written out explicitly so that reordering the enumerators here cannot
// silently change them.
enum class mask_enum
{
    no_mask           = 0,
    mask_top_left     = 1,
    mask_bottom_right = 2,
    window_generic    = 3,
};
|
|
|
|
struct mask_info
|
|
{
|
|
mask_enum type;
|
|
ck_tile::index_t seqlen_q;
|
|
ck_tile::index_t seqlen_k;
|
|
ck_tile::index_t y, x;
|
|
ck_tile::index_t left, right; // FA style SWA left/right
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == mask_enum::no_mask)
|
|
os << "n";
|
|
else if(type == mask_enum::mask_top_left)
|
|
os << "t(" << left << ":" << right << ")";
|
|
else if(type == mask_enum::mask_bottom_right)
|
|
os << "b(" << left << ":" << right << ")";
|
|
else
|
|
{
|
|
os << "g(" << y << ":" << x << ")";
|
|
}
|
|
}
|
|
|
|
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
|
{
|
|
ck_tile::index_t x_total = seqlen_k;
|
|
ck_tile::index_t y_total = seqlen_q;
|
|
mask_info tmp;
|
|
tmp.seqlen_q = seqlen_q;
|
|
tmp.seqlen_k = seqlen_k;
|
|
auto found_0 = str.find(':');
|
|
if(found_0 != std::string::npos)
|
|
{
|
|
std::string t = str.substr(0, found_0);
|
|
std::string v = str.substr(found_0 + 1);
|
|
if(t == "xt" || t == "xb")
|
|
{
|
|
// xformer style sliding window attn from top-left
|
|
ck_tile::index_t window_size = std::stoi(v);
|
|
ck_tile::index_t left_size = -1;
|
|
ck_tile::index_t right_size = 0;
|
|
if(window_size > 0)
|
|
{
|
|
left_size = window_size / 2;
|
|
right_size = window_size - 1 - left_size;
|
|
}
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
left_size, right_size, y_total, x_total, t == "xt");
|
|
|
|
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = left_size;
|
|
tmp.right = right_size;
|
|
}
|
|
else if(t == "t" || t == "b" || t == "g")
|
|
{
|
|
auto found_1 = v.find(",");
|
|
if(found_1 == std::string::npos)
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
|
|
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
|
|
if(t == "t")
|
|
{
|
|
tmp.type = mask_enum::mask_top_left;
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
v0, v1, y_total, x_total, true);
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = v0;
|
|
tmp.right = v1;
|
|
}
|
|
else if(t == "b")
|
|
{
|
|
tmp.type = mask_enum::mask_bottom_right;
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
v0, v1, y_total, x_total, false);
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = v0;
|
|
tmp.right = v1;
|
|
}
|
|
else if(t == "g")
|
|
{
|
|
tmp.type = mask_enum::window_generic;
|
|
tmp.y = v0;
|
|
tmp.x = v1;
|
|
tmp.left = v0; // TODO: don't use this?
|
|
tmp.right = v1;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
}
|
|
else if(str == "0")
|
|
{
|
|
tmp.type = mask_enum::no_mask;
|
|
}
|
|
else if(str == "1" || str == "t")
|
|
{
|
|
tmp.type = mask_enum::mask_top_left;
|
|
tmp.y = seqlen_q;
|
|
tmp.x = 1;
|
|
tmp.left = -1;
|
|
tmp.right = 0;
|
|
}
|
|
else if(str == "2" || str == "b")
|
|
{
|
|
tmp.type = mask_enum::mask_bottom_right;
|
|
tmp.y = seqlen_q;
|
|
tmp.x = seqlen_k - seqlen_q + 1;
|
|
tmp.left = -1;
|
|
tmp.right = 0;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
return tmp;
|
|
}
|
|
|
|
ck_tile::index_t get_unmaskarea() const
|
|
{
|
|
if(type == mask_enum::no_mask)
|
|
return seqlen_q * seqlen_k;
|
|
ck_tile::index_t area = 0;
|
|
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
|
{
|
|
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
|
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
|
if(x_end > x_start)
|
|
{
|
|
area += (x_end - x_start);
|
|
}
|
|
}
|
|
return area;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
|
{
|
|
mi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|