mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE][FMHA] Support gfx11 ## Motivation Add support for gfx11 architectures (RDNA3) to FMHA. ## Technical Details Distributions (the mapping of matrix elements to lane registers) of gfx11 WMMA are completely different from the distributions of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters: * usage of the results (matrix C) of one GEMM as input (matrix A) of another GEMM. * random number generation for dropout (the implementations for gfx9 MFMA, gfx12 WMMA and host validation produce the same results). Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`. Additional changes: * FMHA tests are now built and run only for those types for which instances exist (gfx11 supports only fp16 and bf16). * Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they may contain garbage, resulting in incorrect dispatching logic; sometimes tests report that there are no instances available for the current parameters. * Small fix to remove expcnt(0) from the s_waitcnt instruction on gfx11 when it is not requested (i.e. every time); this likely has no effect on performance but makes the disassembly a bit clearer. ## Test Plan ``` ninja test_ck_tile_fmha bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_bwd_fp16 bin/test_ck_tile_fmha_bwd_bf16 ``` ## Test Result All tests must pass (some tests may be skipped). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
204 lines
6.9 KiB
C++
204 lines
6.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <ostream>
|
|
#include <string>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
|
/// Kind of attention mask selected on the host side.
///
/// The numeric values are written out explicitly; they are the same values the
/// implicit numbering would produce and have to stay in sync with
/// ck_tile::GenericAttentionMaskEnum.
enum class mask_enum
{
    no_mask           = 0, ///< no masking at all
    mask_top_left     = 1, ///< mask anchored at the top-left corner
    mask_bottom_right = 2, ///< mask anchored at the bottom-right corner
    window_generic    = 3, ///< mask given directly by its (y, x) coordinates
};
|
|
|
|
struct mask_info
|
|
{
|
|
mask_enum type;
|
|
ck_tile::index_t seqlen_q;
|
|
ck_tile::index_t seqlen_k;
|
|
ck_tile::index_t y, x;
|
|
ck_tile::index_t left, right; // FA style SWA left/right
|
|
ck_tile::index_t sink;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == mask_enum::no_mask)
|
|
os << "n";
|
|
else if(type == mask_enum::mask_top_left)
|
|
os << "t(" << left << ":" << right << ")";
|
|
else if(type == mask_enum::mask_bottom_right)
|
|
os << "b(" << left << ":" << right << ")";
|
|
else
|
|
{
|
|
os << "g(" << y << ":" << x << ")";
|
|
}
|
|
}
|
|
|
|
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
|
{
|
|
ck_tile::index_t x_total = seqlen_k;
|
|
ck_tile::index_t y_total = seqlen_q;
|
|
mask_info tmp;
|
|
tmp.seqlen_q = seqlen_q;
|
|
tmp.seqlen_k = seqlen_k;
|
|
auto found_0 = str.find(':');
|
|
if(found_0 != std::string::npos)
|
|
{
|
|
std::string t = str.substr(0, found_0);
|
|
std::string v = str.substr(found_0 + 1);
|
|
if(t == "xt" || t == "xb")
|
|
{
|
|
// xformer style sliding window attn from top-left
|
|
ck_tile::index_t window_size = std::stoi(v);
|
|
ck_tile::index_t left_size = -1;
|
|
ck_tile::index_t right_size = 0;
|
|
ck_tile::index_t sink_size = 0;
|
|
if(window_size > 0)
|
|
{
|
|
left_size = window_size / 2;
|
|
right_size = window_size - 1 - left_size;
|
|
}
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
left_size, right_size, sink_size, y_total, x_total, t == "xt");
|
|
|
|
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = left_size;
|
|
tmp.right = right_size;
|
|
tmp.sink = 0;
|
|
}
|
|
else if(t == "t" || t == "b" || t == "g")
|
|
{
|
|
auto found_1 = v.find(",");
|
|
if(found_1 == std::string::npos)
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
tmp.type = mask_enum::window_generic;
|
|
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
|
auto found_2 = v.find(',', found_1 + 1);
|
|
ck_tile::index_t v1 = 0;
|
|
ck_tile::index_t sink = 0;
|
|
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
// TODO: some validation
|
|
if(t == "t")
|
|
{
|
|
if(found_2 != std::string::npos)
|
|
{
|
|
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
|
sink = atoi(v.substr(found_2 + 1).c_str());
|
|
}
|
|
else
|
|
{
|
|
v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
sink = 0;
|
|
}
|
|
tmp.type = mask_enum::mask_top_left;
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
v0, v1, sink, y_total, x_total, true);
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = v0;
|
|
tmp.right = v1;
|
|
tmp.sink = sink;
|
|
}
|
|
else if(t == "b")
|
|
{
|
|
if(found_2 != std::string::npos)
|
|
{
|
|
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
|
sink = atoi(v.substr(found_2 + 1).c_str());
|
|
}
|
|
else
|
|
{
|
|
v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
sink = 0;
|
|
}
|
|
tmp.type = mask_enum::mask_bottom_right;
|
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
|
v0, v1, sink, y_total, x_total, false);
|
|
tmp.y = r.at(ck_tile::number<0>{});
|
|
tmp.x = r.at(ck_tile::number<1>{});
|
|
tmp.left = v0;
|
|
tmp.right = v1;
|
|
tmp.sink = sink;
|
|
}
|
|
else if(t == "g")
|
|
{
|
|
tmp.type = mask_enum::window_generic;
|
|
tmp.y = v0;
|
|
tmp.x = v1;
|
|
tmp.left = v0; // TODO: don't use this?
|
|
tmp.right = v1;
|
|
tmp.sink = 0;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
}
|
|
else if(str == "0")
|
|
{
|
|
tmp.type = mask_enum::no_mask;
|
|
tmp.left = -1;
|
|
tmp.right = -1;
|
|
tmp.sink = 0;
|
|
}
|
|
else if(str == "1" || str == "t")
|
|
{
|
|
tmp.type = mask_enum::mask_top_left;
|
|
tmp.y = seqlen_q;
|
|
tmp.x = 1;
|
|
tmp.left = -1;
|
|
tmp.right = 0;
|
|
tmp.sink = 0;
|
|
}
|
|
else if(str == "2" || str == "b")
|
|
{
|
|
tmp.type = mask_enum::mask_bottom_right;
|
|
tmp.y = seqlen_q;
|
|
tmp.x = seqlen_k - seqlen_q + 1;
|
|
tmp.left = -1;
|
|
tmp.right = 0;
|
|
tmp.sink = 0;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid mask value: " + str);
|
|
}
|
|
return tmp;
|
|
}
|
|
|
|
ck_tile::index_t get_unmaskarea() const
|
|
{
|
|
if(type == mask_enum::no_mask)
|
|
return seqlen_q * seqlen_k;
|
|
ck_tile::index_t area = 0;
|
|
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
|
{
|
|
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
|
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
|
if(x_end > x_start)
|
|
{
|
|
area += (x_end - x_start);
|
|
}
|
|
}
|
|
return area;
|
|
}
|
|
|
|
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi)
|
|
{
|
|
mi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|