mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[CK_TILE] Implement RTC API for a subset of FMHA functionality for MGX (#6086) ## Motivation Introduce a wrapper for the FmhaFwdKernel, for use in runtime compilation (RTC) in MIGraphX. ## Technical Details The intent of the API is to provide multiple instances of the FmhaFwdKernelWrapper, suitable for a particular problem definition. At the moment the wrapper only supports bias and causal masking; feature expansion will come in a future PR. The usage pattern is, in short: 1. Define fmha_fwd::Problem (input dimensions, data type, etc.) 2. Fetch Solutions for the target architecture (currently only gfx942) based on the Problem. The solutions contain a map of template -> template parameter and can be converted to a string representing the full instantiation of FmhaFwdKernelWrapper, e.g. `ck_tile::FmhaFwdWrapper<ck_tile::fp16_t, 128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, false, true, false, true, true, true, true, ck_tile::FmhaPipelineTag::QR>` 3. The instance can then be used in an RTC kernel. The kernel needs to: * Construct a Descriptor (containing descriptions of all input tensors) * Call IsValid() on the descriptor to check if the instance is applicable. Note that this is constexpr by design so that it can fail the kernel compilation as a signal that the kernel is not applicable. * Pass the descriptor and input pointers to the wrapper's Run method. A more detailed example of usage can be found in codegen/test/fmh_fwd.cpp Besides the work on creating the wrapper and the supporting API, the PR also contains some changes necessary to enable compilation with HIPRTC. The contents of the CK tile headers are embedded in a binary file which is used to pass the header files as strings to HIPRTC. Many of the CK tile headers contain host-only code, which leads to compilation failures. ck_tile_headers_preprocessor goes through the embedded headers and removes the bodies of host-only functions, thereby eliminating the compilation failures.
## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
172 lines
4.8 KiB
C++
172 lines
4.8 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck/host/headers.hpp"
|
|
#include <rtc/compile_kernel.hpp>
|
|
#include <rtc/hip.hpp>
|
|
#include <test.hpp>
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <iterator>
|
|
#include <numeric>
|
|
#include <random>
|
|
#include <unordered_set>
|
|
|
|
inline std::vector<rtc::src_file> create_headers_for_test()
|
|
{
|
|
auto ck_headers = ck::host::GetHeaders();
|
|
std::vector<rtc::src_file> result;
|
|
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
|
|
std::string content;
|
|
content.reserve(p.second.size() + 1);
|
|
content.push_back(' '); // We need a whitespace before the content for hipRTC to work
|
|
content.append(p.second.data(), p.second.size());
|
|
return rtc::src_file{p.first, std::move(content)};
|
|
});
|
|
return result;
|
|
}
|
|
|
|
inline const std::vector<rtc::src_file>& get_headers_for_test()
|
|
{
|
|
static const std::vector<rtc::src_file> headers = create_headers_for_test();
|
|
return headers;
|
|
}
|
|
|
|
inline std::vector<rtc::src_file> create_tile_headers_for_test()
|
|
{
|
|
auto headers = ck::host::GetTileHeaders();
|
|
std::vector<rtc::src_file> result;
|
|
std::transform(headers.begin(), headers.end(), std::back_inserter(result), [](auto& p) {
|
|
// Legacy workaround: hipRTC requires a whitespace before the content (reason unknown)
|
|
return rtc::src_file{p.first, " " + std::move(p.second)};
|
|
});
|
|
return result;
|
|
}
|
|
|
|
inline const std::vector<rtc::src_file>& get_tile_headers_for_test()
|
|
{
|
|
static const std::vector<rtc::src_file> headers = create_tile_headers_for_test();
|
|
return headers;
|
|
}
|
|
|
|
// Computes 1 + sum over i of (mLens[i] - 1) * mStrides[i], skipping every index
// whose length is zero. V must provide Size() and operator[].
// With no dimensions (or only zero-length ones) the result is 1.
template <typename V>
std::size_t GetSize(V mLens, V mStrides)
{
    std::size_t extent = 1;
    for(std::size_t dim = 0; dim < mLens.Size(); ++dim)
    {
        // Zero-length dimensions contribute nothing to the total.
        if(mLens[dim] != 0)
        {
            extent += (mLens[dim] - 1) * mStrides[dim];
        }
    }
    return extent;
}
|
|
|
|
template <class T>
|
|
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
|
|
{
|
|
rtc::buffer<T> result(n);
|
|
std::mt19937 gen(seed);
|
|
std::uniform_real_distribution<double> dis(-1.0);
|
|
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
|
|
return result;
|
|
}
|
|
|
|
template <class T, typename V>
|
|
std::enable_if_t<!std::is_integral_v<V>, rtc::buffer<T>>
|
|
generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
|
|
{
|
|
std::size_t space = GetSize(mLens, mStrides);
|
|
return generate_buffer<T>(space, seed);
|
|
}
|
|
|
|
// Element-wise approximate comparison of two ranges.
// Returns true when both ranges have the same length and every pair (x, y), taken
// from a and b respectively and converted to double, satisfies
//     |x - y| < atol + rtol * |y|
// The comparison is strict, so with atol == 0 and rtol == 0 even identical values
// do not compare as close.
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
    const auto within_tolerance = [&](double lhs, double rhs) {
        const double error = std::fabs(lhs - rhs);
        const double bound = atol + rtol * std::fabs(rhs);
        return error < bound;
    };
    return std::equal(a.begin(), a.end(), b.begin(), b.end(), within_tolerance);
}
|
|
|
|
// Maps the std::fpclassify category of x to a short name:
// "inf", "nan", "normal", "subnormal" or "zero"; any other category gives "unknown".
inline std::string classify(double x)
{
    const int category = std::fpclassify(x);
    if(category == FP_INFINITE)
        return "inf";
    if(category == FP_NAN)
        return "nan";
    if(category == FP_NORMAL)
        return "normal";
    if(category == FP_SUBNORMAL)
        return "subnormal";
    if(category == FP_ZERO)
        return "zero";
    return "unknown";
}
|
|
|
|
// Prints the distinct floating-point categories (as named by classify) that occur
// in x, each followed by ", ", then ends the line and flushes std::cout.
// The categories are collected in an unordered set, so their print order is
// whatever order that set iterates in.
template <class Buffer>
void print_classification(const Buffer& x)
{
    std::unordered_set<std::string> categories;
    for(auto it = x.begin(); it != x.end(); ++it)
    {
        categories.insert(classify(*it));
    }
    std::for_each(categories.begin(), categories.end(), [](const std::string& label) {
        std::cout << label << ", ";
    });
    std::cout << std::endl;
}
|
|
|
|
// Prints the minimum, maximum, mean and (population) standard deviation of the
// elements of x to std::cout on a single line:
//     Min value: <min>, Max value: <max>, Mean: <mean>, StdDev: <stddev>
// Min and max are printed as the element type; mean and standard deviation are
// accumulated in double.
//
// An empty buffer has no minimum/maximum and a zero element count, so every field
// is reported as "n/a" instead of dereferencing an end iterator and dividing by zero.
template <class Buffer>
void print_statistics(const Buffer& x)
{
    if(x.size() == 0)
    {
        std::cout << "Min value: n/a, Max value: n/a, Mean: n/a, StdDev: n/a\n";
        return;
    }

    std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
    std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";

    const double num_elements = static_cast<double>(x.size());

    // Convert each element to double before adding so the sum is carried in double.
    const double mean =
        std::accumulate(x.begin(),
                        x.end(),
                        double{0.0},
                        [](double sum, double v) { return sum + v; }) /
        num_elements;

    // Sum of squared deviations from the mean; a plain multiplication replaces
    // the general-purpose std::pow(d, 2.0).
    const double squared_deviation_sum =
        std::accumulate(x.begin(), x.end(), double{0.0}, [&](double sum, double v) {
            const double deviation = v - mean;
            return sum + deviation * deviation;
        });
    const double stddev = std::sqrt(squared_deviation_sum / num_elements);

    std::cout << "Mean: " << mean << ", ";
    std::cout << "StdDev: " << stddev << "\n";
}
|
|
|
|
// Prints a short preview of x to std::cout, each value converted to double and
// followed by ", ". Buffers with at most 10 elements are printed in full; longer
// ones show the first five elements, "..., ", and the last five. The line is then
// ended and std::cout flushed.
template <class Buffer>
void print_preview(const Buffer& x)
{
    const auto dump_range = [](auto first, auto last) {
        for(; first != last; ++first)
        {
            const double value = *first;
            std::cout << value << ", ";
        }
    };

    const bool is_long = !(x.size() <= 10);
    if(is_long)
    {
        dump_range(x.begin(), x.begin() + 5);
        std::cout << "..., ";
        dump_range(x.end() - 5, x.end());
    }
    else
    {
        dump_range(x.begin(), x.end());
    }
    std::cout << std::endl;
}
|
|
|
|
template <class T>
|
|
struct check_all
|
|
{
|
|
rtc::buffer<T> data{};
|
|
bool operator()(const rtc::buffer<T>& x)
|
|
{
|
|
if(data.empty())
|
|
{
|
|
data = x;
|
|
return true;
|
|
}
|
|
return allclose(data, x);
|
|
}
|
|
};
|
|
|
|
template <class Solution>
|
|
auto report(const Solution& solution, bool pass)
|
|
{
|
|
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
|
|
}
|