composable_kernel/codegen/test/include/fmha_fwd_ref.hpp
music-dino 359f664b25 [rocm-libraries] ROCm/rocm-libraries#6086 (commit d25d8cc)
[CK_TILE] Implement RTC API for a subset of FMHA
 functionality for MGX (#6086)

## Motivation

Introduce a wrapper for FmhaFwdKernel for use in runtime compilation
(RTC) in MIGraphX.

## Technical Details

The API provides multiple instances of FmhaFwdKernelWrapper, each
suitable for a particular problem definition.
At the moment the wrapper supports only bias and causal masking; feature
expansion will come in a future PR.
The usage pattern is, in short:

1. Define fmha_fwd::Problem (input dimensions, data type, etc.).
2. Fetch Solutions for the target architecture (currently only gfx942)
   based on the Problem. The solutions contain a map of template ->
   template parameter and can be converted to a string representing the
   full instantiation of FmhaFwdKernelWrapper, e.g.
   `ck_tile::FmhaFwdWrapper<ck_tile::fp16_t, 128, 64, 16, 32, 32, 32, 4,
   1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, false, true, false, true, true,
   true, true, ck_tile::FmhaPipelineTag::QR>`
3. The instance can then be used in an RTC kernel. The kernel needs to:
    * Construct a Descriptor (containing descriptions of all input
      tensors).
    * Call IsValid() on the descriptor to check if the instance is
      applicable. Note that this is constexpr by design so that it can
      fail the kernel compilation as a signal that the kernel is not
      applicable.
    * Pass the descriptor and input pointers to the wrapper Run method.
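
Step 2 can be pictured as joining an ordered list of template arguments into a single type name. The sketch below is only an illustration of that idea, not the PR's implementation: `render_instance` and the shortened argument list are hypothetical, and the real Solution type may store and order its parameters differently.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: joins a template name and its arguments into
// "name<arg0, arg1, ...>". Not part of the CK codegen API.
inline std::string render_instance(const std::string& template_name,
                                   const std::vector<std::string>& args)
{
    std::string out = template_name + "<";
    for(std::size_t i = 0; i < args.size(); ++i)
    {
        if(i != 0)
            out += ", "; // separator between consecutive arguments
        out += args[i];
    }
    out += ">";
    return out;
}
```

For example, `render_instance("ck_tile::FmhaFwdWrapper", {"ck_tile::fp16_t", "128", "64"})` produces `ck_tile::FmhaFwdWrapper<ck_tile::fp16_t, 128, 64>`; a real instantiation takes the full parameter list shown in step 2.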

A more detailed example of usage can be found in
codegen/test/fmh_fwd.cpp
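
The compile-time applicability check in step 3 relies on a standard C++ pattern: a `constexpr` validity function evaluated inside a `static_assert`. A minimal sketch, assuming a hypothetical `ToyDescriptor` with made-up tile constraints (the real Descriptor fields and the conditions it checks are not given in this description):

```cpp
#include <cassert>

// Hypothetical stand-in for a descriptor tied to one kernel instance.
// TileM and MaxHdim are illustrative parameters, not CK API names.
template <int TileM, int MaxHdim>
struct ToyDescriptor
{
    int seqlen_q;
    int hdim_q;

    // constexpr so the result can be consumed by static_assert.
    constexpr bool IsValid() const
    {
        return seqlen_q % TileM == 0 && hdim_q <= MaxHdim;
    }
};

constexpr ToyDescriptor<128, 64> fits{256, 64};
constexpr ToyDescriptor<96, 64> misfit{256, 64};

// Compiles: 256 is a multiple of 128 and hdim 64 is within the limit.
static_assert(fits.IsValid(), "instance not applicable to this problem");

// static_assert(misfit.IsValid(), "...") would stop compilation here,
// which a caller can treat as "try a different instance".
```

Because the check fails at compile time rather than at launch, a caller compiling with RTC can treat a failed build as the signal to move on to the next candidate instance.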

Besides creating the wrapper and the supporting API, the PR also
contains some changes necessary to enable compilation with HIPRTC.
The contents of the CK tile headers are embedded in a binary file, which
is used to pass the header files as strings to HIPRTC.
Many of the CK tile headers contain host-only code, which leads to
compilation failures.
ck_tile_headers_preprocessor goes through the embedded headers and
removes the bodies of host-only functions, thereby eliminating the
compilation failures.
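
The body-stripping step can be sketched as a brace-matching pass over the header text. This is a toy illustration only: the `HOST_ONLY` marker and the `strip_host_bodies` function are hypothetical, the actual ck_tile_headers_preprocessor may identify host-only functions differently, and the sketch ignores braces inside string literals and comments as well as marked declarations that have no body.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Toy sketch: after every occurrence of `marker`, find the next
// brace-delimited body and replace it with an empty body "{}".
inline std::string strip_host_bodies(const std::string& src,
                                     const std::string& marker = "HOST_ONLY")
{
    std::string out;
    std::size_t pos = 0;
    while(true)
    {
        const std::size_t hit = src.find(marker, pos);
        if(hit == std::string::npos)
            break;
        const std::size_t open = src.find('{', hit);
        if(open == std::string::npos)
            break;
        // Walk forward until the brace opened at `open` is closed.
        std::size_t depth = 0;
        std::size_t close = open;
        for(; close < src.size(); ++close)
        {
            if(src[close] == '{')
                ++depth;
            else if(src[close] == '}' && --depth == 0)
                break;
        }
        if(close == src.size())
            break; // unbalanced braces: leave the rest untouched
        out.append(src, pos, open - pos); // text up to the body
        out += "{}";                      // emptied body
        pos = close + 1;
    }
    out.append(src, pos, std::string::npos); // remaining text
    return out;
}
```

Keeping the signature and dropping only the body means device code that merely sees the declaration still parses, while the host-only statements never reach the device compiler.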

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-11 16:22:37 +00:00


// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>
namespace ck {
namespace host {
namespace device_fmha_fwd {

struct FmhaFwdRefParams
{
    std::size_t batch;
    std::size_t nhead;
    std::size_t M; // seqlen_q
    std::size_t N; // seqlen_k
    std::size_t K; // hdim_q
    std::size_t O; // hdim_v
    float scale_s;

    std::size_t q_stride_batch;
    std::size_t q_stride_nhead;
    std::size_t q_stride_m;

    std::size_t k_stride_batch;
    std::size_t k_stride_nhead;
    std::size_t k_stride_n;

    std::size_t v_stride_batch;
    std::size_t v_stride_nhead;
    std::size_t v_stride_n;

    std::size_t o_stride_batch;
    std::size_t o_stride_nhead;
    std::size_t o_stride_m;

    std::size_t bias_stride_batch = 0;
    std::size_t bias_stride_nhead = 0;
    std::size_t bias_stride_m     = 0;
};

// O = softmax(Q @ K^T * scale_s + bias) @ V
// bias is optional (nullptr = no bias)
// All strides are in elements; the innermost dimension of every tensor is
// read with unit stride. The softmax step assumes N > 0.
inline void cpu_attention_ref(const std::vector<float>& q,
                              const std::vector<float>& k,
                              const std::vector<float>& v,
                              std::vector<float>& o,
                              const std::vector<float>* bias,
                              const FmhaFwdRefParams& p)
{
    for(std::size_t b = 0; b < p.batch; ++b)
    {
        for(std::size_t h = 0; h < p.nhead; ++h)
        {
            // Base pointers for the current (batch, head) slice
            const float* q_ptr = q.data() + b * p.q_stride_batch + h * p.q_stride_nhead;
            const float* k_ptr = k.data() + b * p.k_stride_batch + h * p.k_stride_nhead;
            const float* v_ptr = v.data() + b * p.v_stride_batch + h * p.v_stride_nhead;
            const float* bias_ptr =
                bias ? (bias->data() + b * p.bias_stride_batch + h * p.bias_stride_nhead) : nullptr;
            float* o_ptr = o.data() + b * p.o_stride_batch + h * p.o_stride_nhead;

            for(std::size_t m = 0; m < p.M; ++m)
            {
                // Q[m,:] @ K^T -> [N]
                std::vector<float> scores(p.N);
                for(std::size_t n = 0; n < p.N; ++n)
                {
                    float dot = 0.0f;
                    for(std::size_t kk = 0; kk < p.K; ++kk)
                    {
                        dot += q_ptr[m * p.q_stride_m + kk] * k_ptr[n * p.k_stride_n + kk];
                    }
                    scores[n] = dot * p.scale_s;
                    if(bias_ptr)
                    {
                        scores[n] += bias_ptr[m * p.bias_stride_m + n];
                    }
                }

                // Softmax; subtracting the row maximum keeps exp() from overflowing
                float max_score = *std::max_element(scores.begin(), scores.end());
                float sum_exp   = 0.0f;
                for(std::size_t n = 0; n < p.N; ++n)
                {
                    scores[n] = std::exp(scores[n] - max_score);
                    sum_exp += scores[n];
                }
                for(std::size_t n = 0; n < p.N; ++n)
                {
                    scores[n] /= sum_exp;
                }

                // Output: attn @ V -> [O]
                for(std::size_t oo = 0; oo < p.O; ++oo)
                {
                    float val = 0.0f;
                    for(std::size_t n = 0; n < p.N; ++n)
                    {
                        val += scores[n] * v_ptr[n * p.v_stride_n + oo];
                    }
                    o_ptr[m * p.o_stride_m + oo] = val;
                }
            }
        }
    }
}

inline void cpu_attention_ref(const std::vector<float>& q,
                              const std::vector<float>& k,
                              const std::vector<float>& v,
                              std::vector<float>& o,
                              const FmhaFwdRefParams& p)
{
    cpu_attention_ref(q, k, v, o, nullptr, p);
}

} // namespace device_fmha_fwd
} // namespace host
} // namespace ck
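
The reference takes explicit strides, so a caller has to decide on a memory layout before filling FmhaFwdRefParams. The sketch below shows one possible choice, assuming densely packed [batch, nhead, seqlen, hdim] tensors; `PackedStrides` and `packed_bhsd_strides` are hypothetical helpers for illustration and are not part of the header above.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper type: strides, in elements, of a densely packed
// [batch, nhead, seqlen, hdim] tensor. The innermost stride is 1.
struct PackedStrides
{
    std::size_t batch;
    std::size_t nhead;
    std::size_t seqlen;
};

inline PackedStrides packed_bhsd_strides(std::size_t nhead,
                                         std::size_t seqlen,
                                         std::size_t hdim)
{
    PackedStrides s;
    s.seqlen = hdim;                  // step to the next row of a head
    s.nhead  = seqlen * hdim;         // step to the next head
    s.batch  = nhead * seqlen * hdim; // step to the next batch entry
    return s;
}
```

With 2 heads, M = 4 and K = 8, `packed_bhsd_strides(2, 4, 8)` gives q_stride_m = 8, q_stride_nhead = 32 and q_stride_batch = 64. Under the same layout assumption, a bias tensor of shape [batch, nhead, M, N] would use `packed_bhsd_strides(nhead, M, N)`.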