mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CK_TILE] Implement RTC API for a subset of FMHA functionality for MGX (#6086) ## Motivation Introduce a wrapper for the FmhaFwdKernel, for use in run-time compilation (RTC) in MIGraphX. ## Technical Details The intent of the API is to provide multiple instances of the FmhaFwdKernelWrapper, suitable for a particular problem definition. At the moment the wrapper only supports bias and causal masking; feature expansion will come in a future PR. The usage pattern is, in short: 1. Define fmha_fwd::Problem (input dimensions, data type, etc.) 2. Fetch Solutions for target architecture (currently only gfx942) based on Problem. The solutions contain a map of template -> template parameter and can be converted to a string representing the full instantiation of FmhaFwdKernelWrapper, e.g. `ck_tile::FmhaFwdWrapper<ck_tile::fp16_t, 128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, false, true, false, true, true, true, true, ck_tile::FmhaPipelineTag::QR>` 3. The instance can then be used in an RTC kernel. The kernel needs to: * Construct a Descriptor (containing descriptions of all input tensors) * Call IsValid() on the descriptor to check if the instance is applicable. Note that this is constexpr by design so that it can fail the kernel compilation as a signal that the kernel is not applicable. * Pass the descriptor and input pointers to the wrapper Run method. A more detailed example of usage can be found in codegen/test/fmh_fwd.cpp. Besides the work on creating the wrapper and the supporting API, the PR also contains some changes necessary to enable compilation with HIPRTC. The contents of the CK tile headers are embedded in a binary file which is used to pass the header files as strings to HIPRTC. Many of the CK tile headers contain host-only code, which leads to compilation failures. ck_tile_headers_preprocessor goes through the embedded headers and removes the bodies of host-only functions, thereby eliminating the compilation failures.
## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
836 lines
51 KiB
C++
836 lines
51 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "ck/host/device_fmha_fwd/problem.hpp"
|
|
#include "ck/host/device_fmha_fwd/operation.hpp"
|
|
#include "ck/host/stringutils.hpp"
|
|
#include "ck/host/utils.hpp"
|
|
#include "ck/host/headers.hpp"
|
|
#include "common.hpp"
|
|
#include "fmha_fwd_ref.hpp"
|
|
#include <rtc/compile_kernel.hpp>
|
|
#include <rtc/hip.hpp>
|
|
#include <test.hpp>
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
using ck::host::Solution;
|
|
using ck::host::device_fmha_fwd::cpu_attention_ref;
|
|
using ck::host::device_fmha_fwd::FmhaFwdRefParams;
|
|
using ck::host::device_fmha_fwd::Problem;
|
|
|
|
using half = _Float16;
|
|
|
|
// Source text of the device kernel compiled at run time.
//
// Every `${name}` placeholder is filled in by make_kernel_source() through
// ck::host::InterpolateString. The generated kernel `f`:
//   * aliases the instantiated wrapper type (`${template}`) as KernelType/Kernel,
//   * builds a constexpr descriptor from the lengths and strides of Q, K, V, O and bias,
//   * static_asserts desc.IsValid(), so an inapplicable instance fails compilation,
//   * forwards the descriptor, the softmax scale and the raw pointers to Kernel::Run.
//
// The literal is runtime text: everything between the raw-string delimiters is emitted
// verbatim into the kernel source, so it must not be reformatted casually.
const std::string kernel_template = R"__ck__(
|
|
#include <${include}>
|
|
|
|
using KernelType = ${template};
|
|
|
|
extern "C" __launch_bounds__(KernelType::Kernel::kBlockSize, KernelType::Kernel::kBlockPerCu)
|
|
__global__ void f(const ${dtype}* q, const ${dtype}* k, const ${dtype}* v, const ${dtype}* bias, ${dtype}* o) {
|
|
|
|
constexpr float scale_s = ${scale_s};
|
|
|
|
using Kernel = KernelType;
|
|
|
|
constexpr auto desc = Kernel::make_descriptor(
|
|
// Q
|
|
ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${k}),
|
|
ck_tile::make_tuple(${q_stride_batch}, ${q_stride_nhead}, ${q_stride_m}),
|
|
// K
|
|
ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${k}),
|
|
ck_tile::make_tuple(${k_stride_batch}, ${k_stride_nhead}, ${k_stride_n}),
|
|
// V
|
|
ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${o}),
|
|
ck_tile::make_tuple(${v_stride_batch}, ${v_stride_nhead}, ${v_stride_n}),
|
|
// O
|
|
ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${o}),
|
|
ck_tile::make_tuple(${o_stride_batch}, ${o_stride_nhead}, ${o_stride_m}),
|
|
// Bias
|
|
ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${n}),
|
|
ck_tile::make_tuple(${bias_stride_batch}, ${bias_stride_nhead}, ${bias_stride_m}));
|
|
|
|
static_assert(desc.IsValid(), "Invalid FMHA kernel configuration");
|
|
|
|
Kernel::Run(desc, scale_s, q, k, v, bias, o);
|
|
}
|
|
)__ck__";
|
|
|
|
std::string make_kernel_source(const Problem& prob,
|
|
const Solution& solution,
|
|
const FmhaFwdRefParams& ref_params)
|
|
{
|
|
auto template_string = solution.ToTemplateString();
|
|
std::cout << "template_string: " << template_string << std::endl;
|
|
return ck::host::InterpolateString(
|
|
kernel_template,
|
|
{{"include", prob.GetIncludeHeader()},
|
|
{"template", solution.ToTemplateString()},
|
|
{"dtype", "ck_tile::fp16_t"},
|
|
{"batch", std::to_string(ref_params.batch)},
|
|
{"nhead", std::to_string(ref_params.nhead)},
|
|
{"m", std::to_string(ref_params.M)},
|
|
{"n", std::to_string(ref_params.N)},
|
|
{"k", std::to_string(ref_params.K)},
|
|
{"o", std::to_string(ref_params.O)},
|
|
{"q_stride_batch", std::to_string(ref_params.q_stride_batch)},
|
|
{"q_stride_nhead", std::to_string(ref_params.q_stride_nhead)},
|
|
{"q_stride_m", std::to_string(ref_params.q_stride_m)},
|
|
{"k_stride_batch", std::to_string(ref_params.k_stride_batch)},
|
|
{"k_stride_nhead", std::to_string(ref_params.k_stride_nhead)},
|
|
{"k_stride_n", std::to_string(ref_params.k_stride_n)},
|
|
{"v_stride_batch", std::to_string(ref_params.v_stride_batch)},
|
|
{"v_stride_nhead", std::to_string(ref_params.v_stride_nhead)},
|
|
{"v_stride_n", std::to_string(ref_params.v_stride_n)},
|
|
{"o_stride_batch", std::to_string(ref_params.o_stride_batch)},
|
|
{"o_stride_nhead", std::to_string(ref_params.o_stride_nhead)},
|
|
{"o_stride_m", std::to_string(ref_params.o_stride_m)},
|
|
{"bias_stride_batch", std::to_string(ref_params.bias_stride_batch)},
|
|
{"bias_stride_nhead", std::to_string(ref_params.bias_stride_nhead)},
|
|
{"bias_stride_m", std::to_string(ref_params.bias_stride_m)},
|
|
{"scale_s", std::to_string(ref_params.scale_s) + "f"}});
|
|
}
|
|
|
|
// Builds the reference/descriptor parameters for a problem whose Q, K, V and O tensors are
// stored contiguously with the last dimension fastest.
//
// prob    - supplies batch, nhead and the M/N/K/O extents.
// scale_s - softmax scale, copied through unchanged.
//
// Returns a FmhaFwdRefParams with the extents, the scale and the Q/K/V/O strides filled in.
// Bias strides are NOT assigned here; they keep whatever value-initialization gives them.
FmhaFwdRefParams make_ref_params(const Problem& prob, float scale_s)
{
    // Value-initialize so that fields this function does not assign (e.g. the bias strides,
    // which make_kernel_source later reads) are never left indeterminate.
    FmhaFwdRefParams p{};
    p.batch   = prob.batch;
    p.nhead   = prob.nhead;
    p.M       = prob.M;
    p.N       = prob.N;
    p.K       = prob.K;
    p.O       = prob.O;
    p.scale_s = scale_s;

    // Q - [batch, nhead, M, K]
    p.q_stride_m     = prob.K;
    p.q_stride_nhead = prob.M * prob.K;
    p.q_stride_batch = prob.nhead * prob.M * prob.K;

    // K - [batch, nhead, N, K]
    p.k_stride_n     = prob.K;
    p.k_stride_nhead = prob.N * prob.K;
    p.k_stride_batch = prob.nhead * prob.N * prob.K;

    // V - [batch, nhead, N, O]
    p.v_stride_n     = prob.O;
    p.v_stride_nhead = prob.N * prob.O;
    p.v_stride_batch = prob.nhead * prob.N * prob.O;

    // O - [batch, nhead, M, O] contiguous
    p.o_stride_m     = prob.O;
    p.o_stride_nhead = prob.M * prob.O;
    p.o_stride_batch = prob.nhead * prob.M * prob.O;

    return p;
}
|
|
|
|
std::pair<dim3, dim3> get_launch_dims(const Solution& solution, const Problem& prob)
|
|
{
|
|
// Block tile sizes (from TileFmhaShape BlockTile sequence)
|
|
auto bm0 = solution.GetTemplateParameter<std::size_t>("BM0");
|
|
auto bn1 = solution.GetTemplateParameter<std::size_t>("BN1");
|
|
|
|
// Block warps for Gemm0 - sequence<RM0, RN0, RK0>
|
|
auto rm0 = solution.GetTemplateParameter<std::size_t>("RM0");
|
|
auto rn0 = solution.GetTemplateParameter<std::size_t>("RN0");
|
|
auto rk0 = solution.GetTemplateParameter<std::size_t>("RK0");
|
|
|
|
// Block warps for Gemm1 - sequence<RM1, RN1, RK1>
|
|
auto rm1 = solution.GetTemplateParameter<std::size_t>("RM1");
|
|
auto rn1 = solution.GetTemplateParameter<std::size_t>("RN1");
|
|
auto rk1 = solution.GetTemplateParameter<std::size_t>("RK1");
|
|
|
|
const std::size_t warp_size = 64; // gfx9
|
|
const std::size_t num_warps = std::max(rm0 * rn0 * rk0, rm1 * rn1 * rk1);
|
|
const std::size_t block_size = num_warps * warp_size;
|
|
|
|
// Grid dimensions: (nhead, num_m_tiles * num_o_tiles, batch)
|
|
const auto grid_m = ck::host::integer_divide_ceil(prob.M, bm0);
|
|
const auto grid_o = ck::host::integer_divide_ceil(prob.O, bn1);
|
|
|
|
dim3 grid(prob.nhead, grid_m * grid_o, prob.batch);
|
|
dim3 block(block_size, 1, 1);
|
|
|
|
return {grid, block};
|
|
}
|
|
|
|
TEST_CASE(test_fmha_fwd_simple_validation)
|
|
{
|
|
ck::host::device_fmha_fwd::Problem prob;
|
|
prob.M = 24; // seqlen_q
|
|
prob.N = 32; // seqlen_k
|
|
prob.K = 8; // hdim_q (must be multiple of 8)
|
|
prob.O = 16; // hdim_v
|
|
prob.batch = 2;
|
|
prob.nhead = 1;
|
|
prob.dtype = ck::host::DataType::Half;
|
|
prob.is_v_rowmajor = true;
|
|
prob.is_causal = false;
|
|
prob.has_bias = false;
|
|
|
|
const float scale_s = 1.0f;
|
|
|
|
auto solutions = prob.GetSolutions("gfx90a");
|
|
|
|
EXPECT(!solutions.empty());
|
|
|
|
const std::vector<float> q_data = {
|
|
-0.125460f, 0.450714f, 0.231994f, 0.098658f, -0.343981f, -0.344005f, -0.441916f,
|
|
0.366176f, 0.101115f, 0.208073f, -0.479416f, 0.469910f, 0.332443f, -0.287661f,
|
|
-0.318175f, -0.316595f, -0.195758f, 0.024756f, -0.068055f, -0.208771f, 0.111853f,
|
|
-0.360506f, -0.207855f, -0.133638f, -0.043930f, 0.285176f, -0.300326f, 0.014234f,
|
|
0.092415f, -0.453550f, 0.107545f, -0.329476f, -0.434948f, 0.448886f, 0.465632f,
|
|
0.308397f, -0.195386f, -0.402328f, 0.184233f, -0.059848f, -0.377962f, -0.004823f,
|
|
-0.465611f, 0.409320f, -0.241220f, 0.162522f, -0.188289f, 0.020068f, 0.046710f,
|
|
-0.315146f, 0.469585f, 0.275133f, 0.439499f, 0.394827f, 0.097900f, 0.421874f,
|
|
-0.411507f, -0.304017f, -0.454773f, -0.174670f, -0.111323f, -0.228651f, 0.328737f,
|
|
-0.143247f, -0.219065f, 0.042696f, -0.359076f, 0.302197f, -0.425449f, 0.486887f,
|
|
0.272245f, -0.301284f, -0.494478f, 0.315461f, 0.206857f, 0.229007f, 0.271270f,
|
|
-0.425955f, -0.141534f, -0.384131f, 0.363103f, 0.123298f, -0.169102f, -0.436442f,
|
|
-0.189018f, -0.174817f, 0.229606f, 0.137557f, 0.387213f, -0.027785f, -0.380406f,
|
|
0.213245f, 0.260785f, 0.061277f, 0.270967f, -0.006204f, 0.022733f, -0.072459f,
|
|
-0.474581f, -0.392109f, -0.468571f, 0.136410f, -0.185644f, 0.008571f, 0.407566f,
|
|
-0.250708f, -0.089617f, 0.255551f, -0.271202f, -0.423020f, -0.210249f, -0.338779f,
|
|
0.429698f, 0.308120f, 0.133404f, 0.371461f, 0.303672f, -0.313430f, 0.392559f,
|
|
0.039342f, 0.307440f, 0.396091f, -0.181997f, -0.389948f, -0.272065f, -0.072892f,
|
|
0.318015f, 0.360731f, -0.493048f, 0.010747f, -0.082589f, -0.277892f, -0.380135f,
|
|
-0.162385f, 0.442910f, -0.176797f, 0.018791f, 0.203019f, -0.136370f, 0.471782f,
|
|
0.462447f, -0.248218f, -0.002751f, -0.199122f, -0.215160f, -0.463113f, 0.109564f,
|
|
0.002679f, -0.448521f, -0.221354f, 0.408266f, -0.260438f, -0.355105f, -0.010547f,
|
|
0.485650f, -0.257945f, 0.172136f, 0.261620f, -0.262362f, 0.228216f, -0.132217f,
|
|
0.132306f, 0.133530f, 0.035775f, -0.409710f, 0.335303f, -0.179220f, -0.313481f,
|
|
-0.459225f, 0.090893f, 0.177564f, -0.483412f, 0.012093f, -0.273504f, 0.145173f,
|
|
-0.325634f, 0.190938f, -0.113265f, 0.436730f, -0.362479f, -0.158934f, -0.386526f,
|
|
0.424694f, 0.377339f, -0.242058f, 0.159984f, 0.317222f, 0.055201f, 0.029651f,
|
|
-0.258148f, -0.406897f, 0.397216f, 0.400418f, 0.133101f, -0.160970f, -0.150790f,
|
|
0.225956f, 0.397110f, 0.387086f, 0.279876f, 0.142032f, -0.415860f, -0.338371f,
|
|
0.398554f, 0.106429f, -0.490803f, -0.398528f, 0.163502f, -0.494938f, -0.339192f,
|
|
0.048734f, 0.191895f, 0.151961f, -0.275731f, 0.212179f, -0.262751f, -0.174600f,
|
|
0.246491f, 0.149633f, 0.349223f, 0.157613f, 0.068309f, -0.406325f, -0.132284f,
|
|
-0.234798f, -0.256010f, 0.473011f, -0.106902f, 0.392047f, 0.131139f, 0.294811f,
|
|
0.002637f, 0.076904f, -0.007482f, -0.304757f, 0.222452f, -0.219228f, -0.475684f,
|
|
0.145472f, -0.322889f, 0.440459f, 0.453929f, 0.414864f, -0.129841f, -0.484543f,
|
|
0.428319f, -0.071816f, 0.466655f, 0.463620f, 0.353009f, -0.205551f, -0.114902f,
|
|
0.351137f, -0.183078f, -0.330507f, 0.056801f, 0.436155f, 0.196030f, 0.070061f,
|
|
-0.402824f, 0.115007f, 0.490054f, -0.359916f, 0.018330f, 0.377373f, 0.240769f,
|
|
0.197016f, 0.202484f, -0.140509f, -0.206408f, 0.309361f, 0.310113f, 0.367072f,
|
|
0.413241f, 0.011342f, 0.001516f, 0.298295f, 0.149964f, 0.201967f, 0.295793f,
|
|
0.390005f, -0.162005f, -0.124417f, -0.406018f, 0.078280f, -0.464058f, -0.034402f,
|
|
0.042645f, -0.213459f, 0.090833f, -0.469500f, -0.462652f, 0.322601f, -0.139809f,
|
|
-0.372939f, 0.022243f, 0.269994f, -0.284179f, 0.122890f, -0.414653f, -0.448318f,
|
|
0.031355f, 0.040635f, 0.137430f, 0.226091f, 0.475852f, 0.016300f, -0.177044f,
|
|
0.295186f, -0.229168f, -0.061029f, -0.421544f, -0.474649f, 0.462648f, 0.335980f,
|
|
0.195974f, -0.091047f, -0.326706f, -0.343563f, -0.249757f, 0.049227f, 0.214596f,
|
|
0.160197f, -0.220066f, 0.454865f, 0.237897f, 0.054354f, 0.111721f, -0.080400f,
|
|
-0.252269f, -0.144027f, 0.257846f, -0.485607f, -0.383927f, -0.453997f, -0.459271f,
|
|
0.355461f, 0.203658f, -0.025826f, -0.402166f, -0.008384f, -0.026528f, -0.326798f,
|
|
-0.066148f, -0.101495f, 0.115850f, 0.135094f, -0.454696f, -0.125387f, 0.125860f,
|
|
0.003136f, 0.356490f, 0.158694f, -0.337066f, -0.429431f, 0.142419f, -0.473489f,
|
|
0.085776f, 0.440230f, 0.075474f, -0.111830f, 0.143288f, -0.041747f, 0.045617f,
|
|
0.441465f, -0.113897f, 0.461191f, 0.405351f, -0.304209f, -0.430639f, -0.399222f,
|
|
-0.481778f, -0.405557f, 0.183007f, -0.428811f, -0.181024f, 0.344875f, -0.476728f,
|
|
0.314468f, -0.218145f, -0.381835f, 0.196737f, 0.128943f, 0.377472f,
|
|
};
|
|
|
|
const std::vector<float> k_data = {
|
|
0.235071f, 0.303481f, -0.217965f, -0.322560f, 0.250615f, 0.306835f, 0.490505f,
|
|
-0.087382f, -0.127982f, 0.276413f, -0.159196f, 0.430757f, 0.358413f, -0.071006f,
|
|
0.250871f, 0.254543f, -0.396876f, 0.402553f, 0.005252f, 0.326457f, -0.179950f,
|
|
0.395523f, -0.110798f, -0.489162f, 0.405382f, -0.408713f, -0.180686f, 0.450062f,
|
|
0.450607f, 0.073438f, 0.131837f, -0.051554f, -0.206789f, -0.171335f, 0.172518f,
|
|
0.252375f, 0.291579f, 0.289618f, -0.408794f, -0.005580f, -0.442441f, 0.049529f,
|
|
-0.058470f, 0.387704f, -0.149085f, -0.382933f, -0.357008f, 0.261511f, 0.118218f,
|
|
-0.398877f, -0.415893f, 0.200969f, -0.427237f, 0.321860f, 0.206242f, -0.418651f,
|
|
-0.415162f, 0.486640f, -0.125729f, -0.129358f, 0.312800f, 0.447249f, 0.486001f,
|
|
0.253378f, -0.123740f, -0.416499f, 0.277147f, 0.058404f, -0.075778f, 0.406354f,
|
|
-0.388803f, -0.007375f, -0.488646f, -0.031339f, -0.443697f, -0.381182f, -0.382474f,
|
|
0.149210f, 0.246045f, 0.083369f, 0.462173f, -0.125129f, -0.214288f, 0.368599f,
|
|
-0.276404f, 0.463223f, -0.487846f, 0.469879f, -0.456840f, 0.391143f, 0.027701f,
|
|
0.492965f, -0.426203f, 0.053854f, 0.469303f, 0.023098f, 0.129399f, 0.195749f,
|
|
-0.045459f, 0.127558f, 0.084314f, 0.401158f, -0.454554f, -0.219037f, 0.450411f,
|
|
0.390264f, -0.044343f, 0.120133f, -0.222619f, -0.311879f, -0.036302f, -0.146648f,
|
|
0.083656f, -0.422265f, 0.474395f, 0.486211f, 0.198162f, 0.036096f, -0.190472f,
|
|
0.313795f, 0.184731f, -0.337383f, 0.410927f, 0.322537f, 0.449800f, 0.225720f,
|
|
0.113415f, -0.081757f, 0.432728f, 0.366064f, -0.454781f, -0.473633f, -0.123537f,
|
|
0.310553f, 0.487276f, -0.349583f, 0.094131f, -0.119109f, 0.469914f, 0.342119f,
|
|
0.338329f, -0.031307f, -0.085180f, -0.226593f, -0.443624f, 0.364722f, 0.312901f,
|
|
0.499718f, 0.496637f, 0.055432f, 0.268987f, 0.444766f, 0.349647f, -0.252652f,
|
|
-0.049456f, -0.370841f, 0.454051f, 0.106175f, -0.271357f, 0.171701f, 0.118128f,
|
|
-0.141837f, -0.386442f, 0.171573f, 0.020308f, 0.272318f, 0.020164f, 0.352181f,
|
|
0.051907f, 0.060938f, 0.376654f, -0.096517f, -0.365985f, -0.471217f, 0.255137f,
|
|
0.120310f, 0.204080f, -0.287036f, -0.363629f, -0.485455f, -0.149412f, 0.089918f,
|
|
-0.107756f, -0.062525f, 0.404159f, -0.151745f, 0.013989f, 0.283653f, -0.103457f,
|
|
0.122087f, 0.362364f, 0.449521f, -0.352927f, 0.426588f, -0.007884f, -0.241756f,
|
|
-0.040864f, 0.480033f, -0.007382f, -0.171248f, 0.133401f, -0.259854f, -0.424137f,
|
|
-0.371120f, -0.371954f, -0.348097f, -0.361173f, 0.140875f, -0.318120f, -0.154333f,
|
|
0.396788f, -0.026038f, 0.167558f, -0.327680f, -0.307711f, -0.459131f, -0.331065f,
|
|
-0.221410f, -0.322990f, -0.411297f, -0.379364f, -0.039221f, -0.293666f, -0.135730f,
|
|
0.003417f, 0.190395f, -0.460688f, 0.299410f, 0.127900f, -0.418241f, 0.373579f,
|
|
0.420872f, -0.438922f, -0.223122f, 0.306201f, 0.248260f, -0.315479f, -0.290651f,
|
|
-0.129528f, -0.015477f, 0.118255f, -0.131086f, -0.037465f, 0.247471f, -0.463317f,
|
|
-0.247563f, 0.213350f, 0.395207f, 0.011677f, 0.032113f, -0.392828f, -0.052588f,
|
|
0.032617f, -0.257529f, -0.230757f, -0.122716f, -0.479929f, -0.177921f, -0.288552f,
|
|
-0.172503f, -0.380238f, 0.390527f, 0.093592f, 0.179102f, 0.289171f, -0.001558f,
|
|
-0.413080f, 0.037107f, 0.086841f, 0.245439f, -0.068340f, -0.372420f, -0.216224f,
|
|
-0.136918f, 0.145917f, 0.070778f, -0.143903f, 0.486515f, 0.105775f, -0.262773f,
|
|
-0.398218f, -0.347141f, -0.254042f, -0.339319f, -0.313433f, -0.214905f, -0.326626f,
|
|
0.396765f, -0.419766f, 0.024511f, -0.089603f, 0.482379f, -0.387961f, -0.102144f,
|
|
0.469470f, 0.365507f, 0.317072f, -0.242097f, -0.329112f, 0.168643f, 0.429376f,
|
|
0.056763f, 0.071613f, -0.220021f, 0.269493f, -0.312956f, -0.176321f, -0.074564f,
|
|
0.007610f, -0.257590f, -0.385163f, 0.110620f, -0.211369f, 0.081238f, -0.345637f,
|
|
-0.018860f, 0.032589f, -0.448176f, -0.163396f, -0.365585f, -0.436625f, 0.489960f,
|
|
-0.177646f, 0.309874f, -0.245359f, 0.181503f, 0.260228f, 0.095639f, -0.028424f,
|
|
-0.088159f, -0.151132f, 0.429529f, 0.330619f, 0.465027f, -0.375703f, 0.230867f,
|
|
0.438340f, -0.318767f, -0.433504f, 0.241121f, 0.074473f, 0.341829f, -0.360228f,
|
|
0.295267f, -0.298373f, -0.336344f, -0.335734f, 0.314575f, 0.165197f, 0.023065f,
|
|
-0.141170f, 0.377201f, -0.107555f, 0.316599f, -0.060865f, -0.123056f, -0.037320f,
|
|
-0.198622f, 0.247609f, 0.002720f, -0.267787f, 0.399575f, -0.116109f, 0.043553f,
|
|
0.406472f, 0.124238f, -0.383102f, 0.439832f, 0.127708f, -0.165094f, -0.360728f,
|
|
0.294025f, 0.120073f, 0.033461f, 0.393893f, 0.288597f, -0.348325f, -0.188278f,
|
|
-0.251511f, 0.243946f, -0.466468f, 0.069890f, 0.262459f, 0.376766f, -0.157918f,
|
|
0.321257f, -0.389368f, 0.346452f, -0.372511f, -0.102713f, 0.297295f, -0.350083f,
|
|
-0.270749f, 0.222253f, 0.220037f, 0.141148f, 0.193948f, 0.042724f, -0.248201f,
|
|
-0.154304f, -0.318402f, 0.408451f, 0.083392f, -0.099149f, -0.037994f, 0.447283f,
|
|
-0.346649f, 0.086230f, 0.005889f, 0.111454f, -0.481890f, 0.372124f, 0.432118f,
|
|
0.065133f, 0.196651f, 0.422499f, 0.207239f, -0.347461f, 0.076288f, 0.106715f,
|
|
-0.075869f, 0.236444f, 0.434367f, 0.425569f, -0.049161f, -0.386762f, 0.484841f,
|
|
0.338898f, -0.375337f, 0.420842f, 0.369896f, 0.018838f, 0.091275f, -0.100997f,
|
|
-0.445238f, -0.164803f, 0.302853f, -0.495368f, -0.166501f, -0.101831f, 0.037396f,
|
|
0.419856f, -0.153654f, -0.153047f, 0.237501f, -0.047782f, -0.275395f, -0.047560f,
|
|
-0.359143f, -0.323613f, -0.001632f, -0.081075f, 0.414846f, -0.137606f, 0.080588f,
|
|
0.132264f, -0.486906f, 0.163537f, -0.321964f, 0.461070f, -0.351337f, -0.085376f,
|
|
-0.414650f, 0.496874f, 0.002195f, 0.095385f, -0.432924f, 0.249960f, -0.290094f,
|
|
0.398054f, -0.294860f, -0.309312f, -0.463450f, -0.027933f, 0.064841f, -0.434291f,
|
|
0.275528f, -0.046711f, 0.024390f, -0.059237f, -0.099237f, 0.059640f, -0.344760f,
|
|
-0.318072f, 0.361786f, 0.446115f, -0.126691f, -0.229255f, 0.144000f, -0.091266f,
|
|
-0.474614f, -0.343847f, 0.215972f, 0.158924f, -0.472904f, -0.278028f, -0.268925f,
|
|
0.171893f, -0.480289f, -0.395891f, 0.299916f, -0.321455f, 0.152746f, -0.261817f,
|
|
-0.400559f, -0.256828f, 0.222267f, 0.355696f, 0.330220f, -0.102816f, 0.168085f,
|
|
-0.295016f,
|
|
};
|
|
|
|
const std::vector<float> v_data = {
|
|
-0.206852f, 0.396336f, -0.486998f, -0.414491f, -0.292114f, -0.473468f, -0.318565f,
|
|
0.083042f, -0.078575f, 0.392672f, 0.317444f, -0.158183f, -0.240577f, -0.120308f,
|
|
0.090295f, -0.231936f, 0.124149f, -0.090588f, 0.052047f, -0.063873f, -0.205534f,
|
|
0.448453f, 0.263606f, -0.359887f, 0.368468f, -0.012569f, 0.394552f, 0.299855f,
|
|
-0.074786f, -0.477531f, -0.231323f, 0.041634f, 0.133478f, -0.242112f, -0.360644f,
|
|
0.334930f, 0.484402f, 0.025690f, -0.328321f, -0.227693f, -0.481609f, 0.414299f,
|
|
-0.382249f, 0.076516f, -0.225945f, 0.054178f, 0.151420f, 0.329742f, -0.293579f,
|
|
-0.489004f, -0.363114f, 0.400019f, 0.373890f, 0.097413f, 0.100517f, 0.165037f,
|
|
-0.324629f, 0.414412f, -0.081229f, -0.116861f, 0.018918f, -0.453034f, -0.333717f,
|
|
0.238034f, -0.417201f, 0.103152f, -0.254651f, -0.110704f, -0.211306f, -0.144327f,
|
|
0.219046f, -0.202878f, 0.066405f, -0.023950f, 0.163671f, 0.436830f, 0.232572f,
|
|
-0.285060f, -0.468817f, -0.237736f, 0.095078f, -0.448574f, -0.003634f, 0.096843f,
|
|
-0.165756f, 0.270912f, -0.393402f, -0.424862f, 0.228189f, -0.004509f, 0.188402f,
|
|
-0.065173f, -0.253598f, 0.319102f, 0.299416f, 0.194696f, -0.227855f, 0.090231f,
|
|
-0.139026f, -0.408418f, 0.417314f, -0.363181f, 0.450237f, -0.053994f, -0.314867f,
|
|
0.041901f, 0.372946f, 0.232225f, 0.306561f, 0.158783f, 0.192277f, 0.349196f,
|
|
-0.250332f, -0.010575f, -0.278791f, 0.487668f, 0.444059f, -0.460573f, 0.205575f,
|
|
0.425248f, -0.319425f, 0.067945f, 0.415488f, -0.466054f, 0.197420f, -0.202651f,
|
|
0.424396f, 0.471058f, 0.444266f, -0.025786f, 0.362043f, 0.344549f, -0.180900f,
|
|
0.328915f, -0.462992f, 0.096270f, -0.269991f, -0.379433f, -0.423047f, 0.196289f,
|
|
-0.160125f, 0.224767f, -0.434644f, -0.184710f, 0.039491f, 0.290723f, -0.181248f,
|
|
0.125891f, 0.385978f, 0.115863f, -0.267041f, -0.475599f, 0.370099f, -0.478731f,
|
|
0.374702f, 0.028937f, 0.439068f, 0.298783f, 0.497934f, -0.149288f, 0.267188f,
|
|
-0.098069f, -0.020124f, 0.127505f, 0.373677f, 0.484083f, 0.268273f, -0.082233f,
|
|
-0.078643f, 0.237582f, -0.261223f, -0.389526f, -0.145378f, -0.212761f, -0.203692f,
|
|
-0.266392f, -0.457907f, -0.482126f, 0.487722f, -0.072227f, -0.115673f, 0.179647f,
|
|
-0.281746f, 0.449961f, 0.286345f, -0.410589f, -0.082419f, 0.379118f, 0.444732f,
|
|
-0.032598f, 0.113411f, -0.332966f, 0.491169f, -0.268328f, 0.442732f, 0.149647f,
|
|
0.107737f, 0.012689f, -0.269330f, -0.323472f, -0.279514f, -0.313562f, 0.279584f,
|
|
-0.149875f, -0.442157f, 0.469103f, 0.383786f, 0.427752f, 0.494908f, -0.326105f,
|
|
-0.103758f, 0.258238f, 0.196021f, -0.346104f, 0.315833f, -0.275559f, -0.276182f,
|
|
0.036974f, 0.092940f, 0.080086f, -0.408513f, 0.377461f, -0.234400f, -0.370485f,
|
|
0.388748f, 0.455651f, 0.362128f, 0.309516f, 0.155242f, 0.050857f, -0.413013f,
|
|
-0.091547f, -0.127311f, -0.240246f, 0.223420f, -0.004124f, -0.418954f, -0.279817f,
|
|
0.183259f, -0.423869f, 0.351207f, -0.004853f, -0.019413f, 0.092408f, 0.324681f,
|
|
-0.152191f, 0.178016f, 0.065732f, -0.232972f, 0.378630f, 0.297426f, 0.158452f,
|
|
0.350582f, 0.367294f, 0.208363f, 0.337013f, 0.197471f, 0.180141f, 0.118611f,
|
|
0.252717f, -0.341395f, 0.380871f, 0.371844f, -0.470753f, 0.325817f, -0.371130f,
|
|
-0.164881f, 0.243508f, -0.339240f, 0.317967f, 0.332134f, 0.007468f, -0.493614f,
|
|
-0.212962f, 0.116927f, 0.481186f, 0.131814f, -0.240196f, 0.134006f, 0.039985f,
|
|
0.279845f, -0.393019f, 0.261028f, 0.041267f, 0.462992f, -0.158128f, 0.132622f,
|
|
0.432028f, -0.397490f, 0.437229f, 0.187886f, -0.432163f, -0.199036f, 0.208172f,
|
|
-0.432649f, 0.082170f, -0.154117f, 0.120916f, -0.454258f, 0.371537f, 0.473489f,
|
|
0.468878f, 0.249652f, -0.369914f, 0.258263f, -0.475413f, -0.477876f, -0.176390f,
|
|
-0.011357f, 0.270407f, 0.183295f, -0.054097f, -0.226373f, 0.497124f, -0.073819f,
|
|
-0.048613f, -0.336376f, 0.294810f, 0.193682f, -0.279230f, -0.417619f, 0.180499f,
|
|
0.154511f, -0.226740f, 0.450864f, -0.348942f, -0.067665f, 0.443616f, -0.080273f,
|
|
0.138526f, -0.102406f, -0.225785f, 0.483978f, -0.090666f, 0.394099f, -0.270045f,
|
|
-0.286895f, -0.468866f, 0.151667f, -0.131474f, 0.364358f, -0.026790f, 0.468193f,
|
|
-0.314474f, 0.368623f, 0.276597f, 0.270922f, 0.344783f, 0.261024f, 0.126220f,
|
|
-0.368755f, -0.467474f, 0.420848f, 0.116650f, 0.296537f, -0.018478f, -0.382692f,
|
|
-0.374814f, 0.185565f, -0.069694f, -0.299475f, -0.008405f, -0.435791f, 0.081971f,
|
|
-0.231007f, 0.297559f, -0.189638f, -0.044780f, -0.488379f, -0.427553f, -0.107506f,
|
|
-0.020061f, 0.100021f, -0.208337f, 0.194982f, 0.360122f, 0.279851f, -0.460381f,
|
|
-0.019493f, -0.395070f, -0.257955f, 0.486663f, -0.357504f, -0.001112f, 0.118156f,
|
|
0.202465f, 0.059649f, -0.490229f, -0.173539f, 0.017712f, -0.412134f, -0.149373f,
|
|
-0.466797f, -0.421421f, -0.103077f, -0.367284f, 0.067541f, 0.189465f, 0.300587f,
|
|
-0.299850f, -0.332517f, -0.395432f, 0.136430f, 0.206476f, -0.468414f, 0.436212f,
|
|
-0.448029f, 0.041296f, 0.209061f, 0.370969f, 0.214087f, 0.301728f, -0.160550f,
|
|
0.314825f, -0.419885f, 0.394817f, 0.047592f, 0.317298f, -0.047682f, 0.143578f,
|
|
0.026403f, 0.231590f, -0.418370f, -0.439648f, -0.252897f, -0.340455f, 0.371784f,
|
|
-0.280786f, 0.475865f, -0.163104f, -0.317882f, 0.289699f, 0.158708f, -0.001804f,
|
|
0.055364f, 0.219202f, -0.271545f, 0.496334f, 0.474793f, 0.150326f, -0.300458f,
|
|
0.180228f, -0.427802f, -0.469348f, -0.242317f, -0.037377f, 0.368273f, 0.227169f,
|
|
0.242707f, -0.074507f, -0.154065f, -0.128961f, 0.487650f, -0.459891f, 0.367031f,
|
|
0.078675f, -0.061385f, 0.225258f, -0.013331f, 0.373423f, 0.400702f, -0.078279f,
|
|
-0.223172f, 0.092350f, 0.412363f, -0.289338f, 0.122967f, 0.131560f, 0.233113f,
|
|
-0.368432f, 0.215825f, 0.409033f, -0.320317f, -0.262457f, 0.471395f, -0.319023f,
|
|
0.354385f, -0.007722f, -0.252769f, 0.370750f, -0.054695f, 0.014817f, -0.140767f,
|
|
0.092951f, -0.336476f, -0.108918f, 0.469412f, -0.241867f, 0.156737f, -0.174810f,
|
|
0.273473f, -0.369126f, 0.469821f, -0.046210f, -0.263950f, -0.426503f, -0.330242f,
|
|
0.019774f, -0.162997f, 0.328883f, -0.069112f, -0.251286f, 0.117145f, 0.206777f,
|
|
-0.332958f, -0.332381f, -0.463329f, 0.236402f, 0.163805f, -0.025369f, 0.344170f,
|
|
0.305670f, 0.085354f, 0.368271f, -0.294159f, -0.388080f, -0.230250f, -0.442913f,
|
|
0.031170f, 0.436606f, -0.460656f, -0.377890f, -0.047801f, 0.433875f, -0.183844f,
|
|
0.007235f, -0.458427f, -0.351657f, 0.486630f, 0.465119f, -0.495060f, 0.451812f,
|
|
0.139120f, 0.367918f, -0.045260f, 0.015596f, -0.011153f, 0.166864f, -0.360349f,
|
|
-0.470026f, -0.192070f, 0.204681f, -0.298147f, 0.173432f, 0.469912f, -0.406099f,
|
|
0.172602f, -0.056250f, 0.368142f, -0.322850f, 0.192626f, 0.338115f, 0.444614f,
|
|
0.183248f, -0.002825f, 0.117847f, 0.368905f, 0.070610f, -0.469613f, 0.430949f,
|
|
0.189527f, 0.176513f, -0.284325f, 0.158885f, -0.106136f, 0.151233f, -0.393407f,
|
|
0.157845f, 0.499414f, -0.451788f, 0.477174f, -0.093092f, 0.370753f, 0.282385f,
|
|
0.067016f, 0.238449f, 0.378516f, -0.095860f, -0.172967f, 0.167593f, 0.307846f,
|
|
0.262285f, 0.297814f, -0.064417f, 0.317834f, -0.379791f, 0.044489f, -0.494241f,
|
|
-0.175414f, -0.133538f, -0.103827f, 0.195467f, -0.111442f, -0.051306f, -0.262456f,
|
|
-0.126748f, -0.272730f, -0.426804f, 0.103449f, 0.168213f, 0.119490f, -0.036506f,
|
|
-0.120214f, 0.363334f, 0.019082f, -0.020818f, -0.474358f, -0.158752f, -0.119804f,
|
|
-0.101177f, 0.080172f, 0.033603f, 0.107905f, 0.264883f, 0.312986f, 0.218123f,
|
|
0.455524f, -0.481767f, -0.304222f, -0.492437f, 0.147475f, 0.398031f, -0.256518f,
|
|
0.427035f, -0.439733f, 0.434436f, -0.148377f, -0.398579f, -0.014128f, -0.243223f,
|
|
-0.215127f, -0.192710f, 0.303026f, 0.039161f, -0.188692f, 0.110334f, 0.216151f,
|
|
-0.227376f, -0.086451f, -0.378114f, -0.318851f, 0.181118f, -0.318562f, 0.025163f,
|
|
0.209046f, -0.393123f, 0.067312f, -0.243437f, 0.462927f, -0.016454f, 0.305993f,
|
|
0.050227f, -0.456587f, 0.133151f, 0.451403f, 0.101612f, 0.319189f, 0.384206f,
|
|
-0.271920f, -0.287955f, 0.110981f, -0.088972f, 0.339861f, 0.400023f, -0.146579f,
|
|
-0.263129f, 0.280526f, -0.225194f, 0.322614f, -0.076262f, 0.167550f, -0.404465f,
|
|
0.123859f, -0.048232f, 0.086608f, -0.331986f, 0.236874f, 0.362797f, -0.283260f,
|
|
-0.404285f, -0.476361f, 0.141971f, 0.107094f, 0.046697f, -0.268053f, -0.109094f,
|
|
0.094476f, -0.003233f, 0.487786f, -0.363560f, 0.195145f, -0.095681f, -0.071800f,
|
|
0.217598f, 0.192436f, 0.491256f, -0.371606f, -0.395890f, 0.224339f, 0.078387f,
|
|
-0.225839f, -0.420581f, -0.414342f, 0.394191f, -0.308133f, -0.176628f, -0.273344f,
|
|
-0.145004f, -0.430576f, 0.019060f, -0.432387f, 0.300357f, -0.266288f, 0.040012f,
|
|
0.380079f, 0.150877f, 0.032958f, -0.175666f, -0.166998f, 0.169487f, 0.494139f,
|
|
0.161839f, 0.057783f, 0.230651f, -0.034794f, -0.439858f, 0.062297f, 0.457625f,
|
|
-0.324697f, 0.190005f, -0.299066f, 0.035828f, -0.403324f, -0.049629f, 0.256163f,
|
|
-0.152428f, 0.164912f, 0.295450f, 0.427178f, -0.265358f, -0.100684f, -0.347584f,
|
|
0.492483f, 0.427001f, 0.039957f, 0.342033f, 0.020958f, 0.123586f, -0.410876f,
|
|
0.255270f, -0.372287f, 0.326068f, 0.282028f, 0.208745f, -0.463840f, -0.196872f,
|
|
-0.236887f, -0.139864f, -0.412357f, 0.436958f, 0.053802f, -0.194476f, -0.103018f,
|
|
-0.052797f, 0.100594f, 0.015679f, 0.419392f, -0.003037f, 0.492158f, 0.351425f,
|
|
-0.291489f, 0.430595f, -0.383634f, 0.317450f, -0.119377f, 0.377974f, 0.368057f,
|
|
0.305925f, 0.290030f, -0.195321f, -0.419081f, -0.097020f, -0.326475f, 0.194951f,
|
|
-0.153900f, 0.475610f, 0.140972f, 0.322481f, -0.367475f, 0.362014f, 0.422757f,
|
|
-0.012938f, 0.106253f, 0.264810f, -0.325161f, 0.002566f, -0.101337f, -0.353626f,
|
|
-0.132466f, -0.431828f, -0.474188f, -0.364834f, 0.463115f, 0.049530f, 0.465822f,
|
|
-0.067502f, -0.188184f, 0.006142f, -0.060488f, -0.394335f, 0.140826f, -0.283962f,
|
|
0.119588f, 0.150201f, -0.347975f, -0.438650f, 0.280762f, -0.040200f, -0.441836f,
|
|
0.494866f, -0.442219f, 0.195035f, 0.483679f, -0.260820f, -0.357751f, -0.378615f,
|
|
-0.196725f, -0.398954f, 0.192161f, -0.437708f, 0.009422f, 0.496697f, 0.313970f,
|
|
0.115219f, -0.193746f, 0.123896f, 0.027041f, -0.073917f, -0.369290f, 0.386604f,
|
|
-0.050215f, -0.305377f, -0.132241f, -0.085870f, 0.327538f, 0.233614f, 0.269305f,
|
|
-0.488969f, -0.083846f, -0.018656f, -0.480808f, -0.240187f, 0.260290f, -0.362890f,
|
|
0.035310f, -0.284798f, -0.487879f, -0.258799f, 0.475874f, 0.301537f, 0.459577f,
|
|
-0.012146f, -0.390264f, 0.047959f, -0.045623f, 0.344357f, -0.401917f, -0.011759f,
|
|
-0.349951f, -0.175324f, 0.237357f, -0.023982f, -0.124112f, -0.105524f, -0.040553f,
|
|
0.285017f, 0.392085f, 0.455335f, 0.286903f, -0.184593f, 0.188135f, -0.062397f,
|
|
-0.245329f, 0.340872f, -0.461574f, 0.401762f, -0.038523f, 0.137201f, 0.159354f,
|
|
0.395118f, 0.136670f, 0.113934f, -0.433348f, 0.018408f, -0.349831f, 0.237434f,
|
|
0.012222f, 0.180228f, -0.458327f, -0.415208f, 0.216323f, -0.427916f, -0.428743f,
|
|
-0.487892f, 0.456501f, 0.237508f, -0.146749f, -0.203464f, -0.150297f, 0.274654f,
|
|
0.161371f, -0.314804f, -0.325891f, -0.401604f, 0.160303f, 0.264373f, -0.234954f,
|
|
-0.479055f, -0.417828f, 0.467860f, -0.204555f, 0.269223f, 0.124664f, -0.118060f,
|
|
-0.294313f, -0.378614f, 0.115013f, 0.274634f, 0.143904f, 0.030302f, -0.458049f,
|
|
0.468489f, 0.298714f, -0.207178f, 0.479970f, 0.101882f, 0.082423f, 0.248073f,
|
|
0.311770f, 0.156479f, -0.371904f, -0.161732f, 0.428084f, -0.275384f, -0.127833f,
|
|
-0.067923f, -0.060595f, 0.112940f, 0.443076f, -0.259307f, -0.378499f, -0.302530f,
|
|
0.386925f, 0.145811f, -0.214093f, 0.315947f, 0.361370f, 0.346514f, 0.418927f,
|
|
-0.247759f, 0.255042f, -0.039461f, 0.341999f, 0.228491f, 0.276447f, 0.156162f,
|
|
-0.322571f, 0.045027f, 0.484670f, 0.437388f, -0.456826f, -0.335185f, -0.368271f,
|
|
0.225980f, 0.317785f, -0.286489f, 0.005853f, 0.340703f, 0.232802f, 0.042237f,
|
|
0.090348f, 0.008361f, -0.202452f, 0.065022f, 0.188885f, 0.373323f, 0.136291f,
|
|
0.261122f, -0.339928f, -0.038443f, -0.490668f, -0.253321f, 0.226462f, 0.491810f,
|
|
-0.400822f, -0.098506f, 0.300071f, -0.295964f, 0.055085f, 0.233071f, 0.115985f,
|
|
-0.311975f, -0.144615f, 0.283792f, 0.054227f, -0.494770f, 0.260991f, -0.464689f,
|
|
0.245734f, -0.297519f, 0.458073f, -0.132059f, -0.173068f, -0.351112f, -0.194396f,
|
|
0.376651f, 0.496334f, -0.131690f, -0.051389f, 0.222071f, 0.386196f, 0.093044f,
|
|
-0.108474f, -0.087378f,
|
|
};
|
|
|
|
const std::vector<float> numpy_expected = {
|
|
0.007383f, -0.085425f, 0.011838f, 0.062971f, 0.043929f, 0.007666f, 0.008439f,
|
|
-0.046630f, -0.058420f, -0.034030f, 0.050607f, 0.002766f, 0.056086f, 0.071142f,
|
|
0.003148f, -0.008505f, 0.002715f, -0.076216f, -0.014847f, 0.068649f, 0.058922f,
|
|
-0.008740f, 0.021790f, -0.043732f, -0.082332f, -0.014314f, 0.041560f, 0.015328f,
|
|
0.045330f, 0.052070f, 0.014844f, 0.026025f, 0.007508f, -0.065677f, -0.006289f,
|
|
0.065917f, 0.036876f, 0.000431f, 0.013452f, -0.047478f, -0.076925f, -0.027326f,
|
|
0.047549f, 0.003660f, 0.052550f, 0.068205f, 0.015890f, 0.019385f, -0.002520f,
|
|
-0.068157f, -0.014357f, 0.059441f, 0.046273f, -0.015606f, 0.029188f, -0.047057f,
|
|
-0.067481f, -0.025480f, 0.048960f, 0.016361f, 0.055688f, 0.066174f, 0.022904f,
|
|
0.016228f, -0.017850f, -0.077436f, 0.015345f, 0.052739f, 0.056457f, -0.008167f,
|
|
-0.002618f, -0.035080f, -0.054646f, -0.047784f, 0.064118f, 0.021038f, 0.098352f,
|
|
0.061559f, 0.014207f, -0.006122f, -0.002099f, -0.067341f, -0.000756f, 0.057148f,
|
|
0.059963f, 0.001503f, 0.010144f, -0.032881f, -0.075191f, -0.032237f, 0.037420f,
|
|
0.001029f, 0.060923f, 0.060398f, 0.030673f, 0.012808f, -0.006748f, -0.047749f,
|
|
0.000415f, 0.060475f, 0.069737f, -0.008651f, 0.004705f, -0.012828f, -0.077261f,
|
|
-0.017083f, 0.051994f, 0.003326f, 0.062779f, 0.048019f, 0.008298f, -0.012594f,
|
|
-0.007749f, -0.055491f, -0.012014f, 0.053954f, 0.045582f, -0.010534f, 0.030729f,
|
|
-0.036889f, -0.063309f, -0.032229f, 0.049988f, 0.004904f, 0.070313f, 0.069882f,
|
|
0.033285f, 0.018283f, -0.009560f, -0.056328f, -0.007101f, 0.047559f, 0.067232f,
|
|
-0.013676f, 0.019708f, -0.032811f, -0.078113f, -0.040424f, 0.039800f, 0.003230f,
|
|
0.060881f, 0.069153f, 0.049097f, 0.012857f, -0.003914f, -0.063199f, 0.001035f,
|
|
0.065549f, 0.052037f, -0.002653f, -0.013828f, -0.048785f, -0.080286f, -0.041294f,
|
|
0.059457f, 0.014830f, 0.082938f, 0.054519f, 0.019383f, 0.025542f, 0.001185f,
|
|
-0.064716f, -0.015948f, 0.052071f, 0.032986f, -0.014907f, 0.051420f, -0.044499f,
|
|
-0.053381f, -0.017821f, 0.042237f, 0.002952f, 0.031287f, 0.084531f, 0.017001f,
|
|
-0.008584f, -0.010784f, -0.064312f, -0.024903f, 0.052547f, 0.063267f, -0.024236f,
|
|
0.046386f, -0.025896f, -0.068553f, -0.006001f, 0.044032f, 0.006031f, 0.043641f,
|
|
0.056054f, 0.016689f, 0.004116f, 0.014393f, -0.058293f, -0.004851f, 0.058634f,
|
|
0.027928f, 0.008397f, 0.033760f, -0.046834f, -0.072747f, -0.025939f, 0.024793f,
|
|
-0.008613f, 0.026162f, 0.088906f, 0.032530f, 0.011598f, 0.010774f, -0.087746f,
|
|
-0.002402f, 0.076286f, 0.052772f, -0.007808f, 0.042321f, -0.044525f, -0.074307f,
|
|
-0.020356f, 0.050978f, 0.005467f, 0.041848f, 0.067021f, -0.013176f, 0.016990f,
|
|
-0.018131f, -0.073032f, -0.014444f, 0.052988f, 0.066205f, -0.028847f, 0.041022f,
|
|
-0.028227f, -0.053479f, -0.012696f, 0.059475f, 0.020471f, 0.064025f, 0.053843f,
|
|
0.002226f, -0.009378f, -0.006675f, -0.061330f, -0.016546f, 0.045374f, 0.038021f,
|
|
-0.019298f, 0.049954f, -0.040340f, -0.044663f, -0.022905f, 0.044510f, 0.000977f,
|
|
0.038488f, 0.082866f, 0.025464f, -0.019278f, -0.009946f, -0.056392f, -0.003774f,
|
|
0.051014f, 0.046133f, -0.009736f, 0.021107f, -0.040785f, -0.057193f, -0.047951f,
|
|
0.055886f, 0.003465f, 0.078724f, 0.075681f, 0.040318f, 0.006164f, -0.009899f,
|
|
-0.067255f, -0.012504f, 0.061307f, 0.063530f, -0.014937f, 0.016265f, -0.035016f,
|
|
-0.074253f, -0.016603f, 0.052519f, 0.019856f, 0.065436f, 0.046476f, 0.014571f,
|
|
0.015569f, -0.005469f, -0.070110f, 0.003504f, 0.058781f, 0.054405f, -0.013541f,
|
|
0.035046f, -0.035151f, -0.061428f, -0.041955f, 0.064034f, 0.004731f, 0.079533f,
|
|
0.069533f, 0.006321f, 0.009739f, 0.009868f, -0.046759f, 0.003892f, 0.060610f,
|
|
0.044778f, 0.004380f, -0.013117f, -0.035925f, -0.088403f, -0.036423f, 0.046171f,
|
|
-0.005440f, 0.057470f, 0.064779f, 0.022364f, 0.000553f, 0.014907f, -0.062145f,
|
|
0.003694f, 0.063011f, 0.053007f, 0.000731f, 0.003884f, -0.046303f, -0.090317f,
|
|
-0.042867f, 0.037300f, -0.004294f, 0.044668f, 0.074411f, 0.030016f, 0.013970f,
|
|
0.002469f, -0.050964f, -0.006501f, 0.059326f, 0.037477f, -0.004060f, 0.006490f,
|
|
-0.050532f, -0.076494f, -0.042087f, 0.054995f, -0.000966f, 0.067863f, 0.072168f,
|
|
0.032234f, 0.017786f, -0.011112f, -0.075392f, -0.003143f, 0.052040f, 0.047606f,
|
|
-0.019149f, 0.046299f, -0.035092f, -0.041081f, -0.022936f, 0.065448f, 0.005120f,
|
|
0.065054f, 0.074648f, -0.008866f, -0.023949f, 0.005304f, -0.069631f, 0.009495f,
|
|
0.062978f, 0.044818f, 0.007730f, -0.001488f, -0.040640f, -0.066867f, -0.031884f,
|
|
0.052568f, 0.003658f, 0.061925f, 0.062329f, 0.004855f, -0.003895f, 0.114486f,
|
|
0.079661f, -0.115023f, 0.025315f, -0.000117f, -0.070439f, -0.009776f, 0.115430f,
|
|
0.047095f, -0.020249f, 0.001512f, -0.006185f, -0.036645f, 0.003067f, -0.048612f,
|
|
-0.035854f, 0.100041f, 0.077085f, -0.109820f, 0.015464f, -0.021206f, -0.063925f,
|
|
-0.009368f, 0.121258f, 0.055209f, -0.034103f, 0.008018f, -0.008480f, -0.026955f,
|
|
-0.004989f, -0.046626f, -0.017247f, 0.099377f, 0.074604f, -0.113369f, 0.007508f,
|
|
-0.004265f, -0.095276f, -0.026419f, 0.115797f, 0.092064f, -0.031276f, 0.007216f,
|
|
0.010462f, -0.008152f, -0.001692f, -0.045870f, -0.039668f, 0.090610f, 0.070008f,
|
|
-0.095893f, 0.036339f, -0.001674f, -0.076347f, -0.010227f, 0.120200f, 0.066155f,
|
|
-0.008440f, 0.010495f, -0.005206f, -0.039893f, -0.013893f, -0.045189f, -0.045747f,
|
|
0.101430f, 0.064898f, -0.104166f, 0.004913f, 0.012145f, -0.097956f, -0.028537f,
|
|
0.107966f, 0.079144f, -0.029408f, 0.001147f, 0.011118f, 0.002440f, 0.012128f,
|
|
-0.048582f, -0.051894f, 0.096293f, 0.093928f, -0.130731f, 0.027540f, -0.020008f,
|
|
-0.071251f, -0.015406f, 0.119832f, 0.084302f, -0.018117f, 0.013128f, 0.001765f,
|
|
-0.034212f, -0.008191f, -0.050701f, -0.026755f, 0.091574f, 0.058934f, -0.109406f,
|
|
0.031684f, 0.013173f, -0.073829f, -0.022562f, 0.122142f, 0.038862f, -0.031264f,
|
|
0.031566f, -0.011584f, -0.034398f, 0.001449f, -0.050027f, -0.034705f, 0.093171f,
|
|
0.092271f, -0.107576f, 0.039219f, -0.015123f, -0.054276f, -0.009520f, 0.109212f,
|
|
0.061468f, 0.000427f, -0.008970f, -0.002040f, -0.047295f, -0.001064f, -0.047281f,
|
|
-0.044790f, 0.096935f, 0.078937f, -0.092781f, 0.036182f, 0.018153f, -0.056738f,
|
|
-0.020583f, 0.109120f, 0.059436f, 0.001769f, -0.000911f, -0.003321f, -0.044719f,
|
|
0.010452f, -0.055386f, -0.059634f, 0.102150f, 0.071935f, -0.123576f, 0.025914f,
|
|
-0.014051f, -0.072845f, -0.011868f, 0.121021f, 0.055033f, -0.033752f, 0.019387f,
|
|
-0.010922f, -0.028995f, -0.004246f, -0.047819f, -0.017015f, 0.105048f, 0.077451f,
|
|
-0.111607f, 0.034564f, -0.009339f, -0.068584f, -0.006664f, 0.115148f, 0.050124f,
|
|
-0.015425f, 0.001799f, -0.009177f, -0.041747f, -0.004707f, -0.044247f, -0.034380f,
|
|
0.089478f, 0.095989f, -0.120383f, 0.017656f, -0.012592f, -0.064598f, -0.025977f,
|
|
0.108387f, 0.079686f, -0.023188f, -0.005437f, 0.007509f, -0.017324f, 0.016442f,
|
|
-0.047355f, -0.041292f, 0.088404f, 0.096468f, -0.106369f, 0.030468f, -0.002639f,
|
|
-0.071193f, -0.031953f, 0.110465f, 0.079732f, -0.007768f, -0.008830f, 0.009252f,
|
|
-0.039151f, 0.001257f, -0.033481f, -0.059676f, 0.097944f, 0.077180f, -0.121401f,
|
|
0.012640f, 0.007468f, -0.074098f, -0.035356f, 0.119044f, 0.065360f, -0.039596f,
|
|
0.019601f, 0.003659f, -0.011478f, 0.017890f, -0.054258f, -0.035000f, 0.084231f,
|
|
0.097995f, -0.112847f, 0.040351f, -0.010664f, -0.064514f, -0.018276f, 0.105636f,
|
|
0.089613f, 0.009693f, -0.009866f, 0.010465f, -0.039897f, -0.002313f, -0.049904f,
|
|
-0.058403f, 0.070641f, 0.071798f, -0.100271f, 0.039201f, -0.008508f, -0.083737f,
|
|
-0.027235f, 0.118331f, 0.089070f, -0.008805f, 0.014908f, 0.001467f, -0.036324f,
|
|
-0.016604f, -0.039218f, -0.046911f, 0.098816f, 0.076214f, -0.100256f, 0.028171f,
|
|
0.004775f, -0.077152f, -0.019596f, 0.110189f, 0.066893f, -0.010315f, -0.006468f,
|
|
0.000782f, -0.030826f, 0.003657f, -0.043033f, -0.054865f, 0.086425f, 0.084144f,
|
|
-0.118041f, 0.027978f, -0.008248f, -0.070637f, -0.022207f, 0.122238f, 0.085020f,
|
|
-0.018110f, 0.021525f, 0.001915f, -0.029623f, -0.005897f, -0.053233f, -0.031485f,
|
|
0.087198f, 0.088724f, -0.105858f, 0.035310f, 0.000956f, -0.058345f, -0.024886f,
|
|
0.109864f, 0.074043f, -0.003251f, -0.000581f, 0.002690f, -0.039096f, 0.008599f,
|
|
-0.051477f, -0.051773f, 0.091308f, 0.072021f, -0.109828f, 0.022452f, 0.006161f,
|
|
-0.081970f, -0.037871f, 0.119763f, 0.065898f, -0.032224f, 0.013211f, 0.000856f,
|
|
-0.025350f, 0.007624f, -0.039163f, -0.044446f, 0.118941f, 0.081076f, -0.135352f,
|
|
0.014605f, -0.005394f, -0.076476f, -0.015260f, 0.130009f, 0.054056f, -0.041352f,
|
|
0.026264f, -0.004701f, -0.031809f, -0.001052f, -0.052770f, -0.014205f, 0.108131f,
|
|
0.074984f, -0.117471f, 0.021073f, -0.014417f, -0.085287f, -0.012516f, 0.116266f,
|
|
0.060459f, -0.033259f, 0.000532f, -0.005881f, -0.027825f, -0.007172f, -0.034112f,
|
|
-0.028664f, 0.093271f, 0.087344f, -0.111297f, 0.019828f, 0.012828f, -0.075017f,
|
|
-0.041729f, 0.122014f, 0.080034f, -0.025663f, 0.017053f, 0.010779f, -0.029446f,
|
|
0.010609f, -0.048868f, -0.050144f, 0.108120f, 0.065707f, -0.121392f, 0.001176f,
|
|
0.013212f, -0.079736f, -0.031379f, 0.120274f, 0.045055f, -0.054592f, 0.025725f,
|
|
0.000130f, 0.001714f, 0.019266f, -0.056784f, -0.029767f,
|
|
};
|
|
|
|
const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O;
|
|
std::vector<float> o_ref(o_size);
|
|
auto ref_params = make_ref_params(prob, scale_s);
|
|
cpu_attention_ref(q_data, k_data, v_data, o_ref, ref_params);
|
|
|
|
EXPECT(allclose(o_ref, numpy_expected, 0.0001, 0.0001));
|
|
|
|
for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx)
|
|
{
|
|
auto&& solution = solutions[sol_idx];
|
|
std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl;
|
|
|
|
auto srcs = get_tile_headers_for_test();
|
|
srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)});
|
|
|
|
rtc::compile_options options;
|
|
options.kernel_name = "f";
|
|
auto kernel = rtc::compile_kernel(srcs, options);
|
|
|
|
auto [grid, block] = get_launch_dims(solution, prob);
|
|
|
|
rtc::buffer<half> o_host(o_size);
|
|
std::fill(o_host.begin(), o_host.end(), half(0.0f));
|
|
auto o_device = to_gpu(o_host);
|
|
const auto make_device_buff = [&](const std::vector<float>& data) {
|
|
rtc::buffer<half> host(data.size());
|
|
std::transform(
|
|
data.begin(), data.end(), host.begin(), [](float val) { return half(val); });
|
|
return to_gpu(host);
|
|
};
|
|
auto q_device = make_device_buff(q_data);
|
|
auto k_device = make_device_buff(k_data);
|
|
auto v_device = make_device_buff(v_data);
|
|
|
|
kernel.launch(nullptr, grid, block)(q_device.data(),
|
|
k_device.data(),
|
|
v_device.data(),
|
|
static_cast<half*>(nullptr),
|
|
o_device.data());
|
|
o_host = rtc::from_gpu(o_device);
|
|
std::vector<float> result(o_size);
|
|
std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) {
|
|
return static_cast<float>(v);
|
|
});
|
|
CHECK(allclose(result, o_ref, 0.0001, 0.0001));
|
|
}
|
|
}
|
|
|
|
// Checks every FMHA forward solution returned for an fp16 problem with batch=4, nhead=8,
// seqlen_q=128, seqlen_k=256, hdim_q=32, hdim_v=64 (has_bias=false, is_causal=false).
// Each solution is built with rtc::compile_kernel, launched on pseudo-random inputs, and its
// output is compared with the output of cpu_attention_ref on the same inputs.
TEST_CASE(test_fmha_fwd_4_8_128_256_32_64)
{
    ck::host::device_fmha_fwd::Problem prob;
    prob.M             = 128; // seqlen_q
    prob.N             = 256; // seqlen_k
    prob.K             = 32;  // hdim_q
    prob.O             = 64;  // hdim_v
    prob.batch         = 4;
    prob.nhead         = 8;
    prob.dtype         = ck::host::DataType::Half;
    prob.is_v_rowmajor = true;
    prob.is_causal     = false;
    prob.has_bias      = false;

    // Scale factor 1/sqrt(hdim_q); handed to make_ref_params below.
    const float scale_s = 1.0f / std::sqrt(static_cast<float>(prob.K));

    auto solutions = prob.GetSolutions("gfx90a");

    EXPECT(!solutions.empty());

    // Element counts of the Q, K, V and O tensors.
    const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K;
    const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K;
    const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O;
    const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O;

    // Fixed seed keeps the generated inputs identical from run to run.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);

    rtc::buffer<half> q_host(q_size), k_host(k_size), v_host(v_size);
    std::vector<float> q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size);

    // Draws one value per element; the device buffer receives it converted to half while the
    // reference vector keeps the original float.
    auto fill_buffers = [&](auto& host, auto& ref) {
        for(std::size_t i = 0; i < host.size(); ++i)
        {
            float val = dist(rng);
            host[i]   = half(val);
            ref[i]    = val;
        }
    };
    // The call order fixes which random numbers each tensor gets; keep it Q, K, V.
    fill_buffers(q_host, q_ref);
    fill_buffers(k_host, k_ref);
    fill_buffers(v_host, v_ref);

    auto ref_params = make_ref_params(prob, scale_s);
    cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params);

    for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx)
    {
        auto&& solution = solutions[sol_idx];
        std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl;

        // Kernel source for this solution is added as "main.cpp" next to the tile headers.
        auto srcs = get_tile_headers_for_test();
        srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)});

        rtc::compile_options options;
        options.kernel_name = "f";
        auto kernel         = rtc::compile_kernel(srcs, options);

        auto [grid, block] = get_launch_dims(solution, prob);

        // Output starts zero-filled so values left over from a previous solution cannot leak in.
        rtc::buffer<half> o_host(o_size);
        std::fill(o_host.begin(), o_host.end(), half(0.0f));
        auto o_device = to_gpu(o_host);
        auto q_device = to_gpu(q_host);
        auto k_device = to_gpu(k_host);
        auto v_device = to_gpu(v_host);

        // The fourth kernel argument is the bias pointer in test_fmha_fwd_with_bias; this
        // problem has has_bias=false, so a null pointer is passed.
        kernel.launch(nullptr, grid, block)(q_device.data(),
                                            k_device.data(),
                                            v_device.data(),
                                            static_cast<half*>(nullptr),
                                            o_device.data());

        // Copy the device output back and widen it to float for the comparison.
        o_host = rtc::from_gpu(o_device);
        std::vector<float> result(o_size);
        std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) {
            return static_cast<float>(v);
        });

        // (actual, expected) order, matching the other allclose checks in this file.
        // NOTE(review): this only changes the outcome if allclose scales its relative
        // tolerance by the second argument (as NumPy's allclose does) - confirm.
        CHECK(allclose(result, o_ref, 0.0001, 0.0001));
    }
}
|
|
|
|
// Runs every FMHA forward solution for a small fp16 problem with has_bias=true and compares
// the device output against cpu_attention_ref. The bias holds M*N elements and is reused for
// every batch and head (batch and nhead strides are set to 0).
TEST_CASE(test_fmha_fwd_with_bias)
{
    ck::host::device_fmha_fwd::Problem prob;
    prob.M             = 64;  // seqlen_q
    prob.N             = 128; // seqlen_k
    prob.K             = 32;  // hdim_q
    prob.O             = 32;  // hdim_v
    prob.batch         = 2;
    prob.nhead         = 4;
    prob.dtype         = ck::host::DataType::Half;
    prob.is_v_rowmajor = true;
    prob.is_causal     = false;
    prob.has_bias      = true;

    const float scale = 1.0f / std::sqrt(static_cast<float>(prob.K));

    auto solutions = prob.GetSolutions("gfx90a");
    EXPECT(!solutions.empty());

    // Element counts per tensor; batch * nhead is the shared leading factor.
    const auto batch_heads         = prob.batch * prob.nhead;
    const std::size_t q_count      = batch_heads * prob.M * prob.K;
    const std::size_t k_count      = batch_heads * prob.N * prob.K;
    const std::size_t v_count      = batch_heads * prob.N * prob.O;
    const std::size_t o_size       = batch_heads * prob.M * prob.O;
    const std::size_t bias_count   = prob.M * prob.N; // Only [M, N], broadcast across batch/nhead

    // One seeded generator feeds all tensors, so the fill order below must not change.
    std::mt19937 gen(43);
    std::uniform_real_distribution<float> value_dist(-0.5f, 0.5f);
    std::uniform_real_distribution<float> bias_value_dist(-0.1f, 0.1f);

    rtc::buffer<half> q_half(q_count);
    rtc::buffer<half> k_half(k_count);
    rtc::buffer<half> v_half(v_count);
    rtc::buffer<half> bias_half(bias_count);

    std::vector<float> q_float(q_count);
    std::vector<float> k_float(k_count);
    std::vector<float> v_float(v_count);
    std::vector<float> bias_float(bias_count);
    std::vector<float> o_ref(o_size);

    // Writes each drawn sample to the half buffer (converted) and to the float mirror (as is).
    auto populate = [&](auto& half_buf, auto& float_buf, auto& sampler) {
        const std::size_t count = half_buf.size();
        for(std::size_t idx = 0; idx < count; ++idx)
        {
            const float sample = sampler(gen);
            half_buf[idx]      = half(sample);
            float_buf[idx]     = sample;
        }
    };
    populate(q_half, q_float, value_dist);
    populate(k_half, k_float, value_dist);
    populate(v_half, v_float, value_dist);
    populate(bias_half, bias_float, bias_value_dist);

    auto params              = make_ref_params(prob, scale);
    params.bias_stride_batch = 0;
    params.bias_stride_nhead = 0;
    params.bias_stride_m     = prob.N;
    cpu_attention_ref(q_float, k_float, v_float, o_ref, &bias_float, params);

    const std::size_t total = solutions.size();
    for(std::size_t n = 0; n < total; ++n)
    {
        auto&& candidate = solutions[n];
        std::cout << "Testing solution " << (n + 1) << "/" << solutions.size() << std::endl;

        auto sources = get_tile_headers_for_test();
        sources.push_back({"main.cpp", make_kernel_source(prob, candidate, params)});

        rtc::compile_options opts;
        opts.kernel_name = "f";
        auto compiled    = rtc::compile_kernel(sources, opts);

        auto [grid_dim, block_dim] = get_launch_dims(candidate, prob);

        // Zero-filled output buffer, uploaded before the inputs as in the other tests.
        rtc::buffer<half> out_half(o_size);
        std::fill(out_half.begin(), out_half.end(), half(0.0f));
        auto out_gpu  = to_gpu(out_half);
        auto q_gpu    = to_gpu(q_half);
        auto k_gpu    = to_gpu(k_half);
        auto v_gpu    = to_gpu(v_half);
        auto bias_gpu = to_gpu(bias_half);

        auto launcher = compiled.launch(nullptr, grid_dim, block_dim);
        launcher(q_gpu.data(), k_gpu.data(), v_gpu.data(), bias_gpu.data(), out_gpu.data());

        // Read the output back and widen each half to float.
        out_half = rtc::from_gpu(out_gpu);
        std::vector<float> result(o_size);
        auto dst = result.begin();
        for(auto src = out_half.begin(); src != out_half.end(); ++src, ++dst)
        {
            const half h = *src;
            *dst         = static_cast<float>(h);
        }

        CHECK(allclose(result, o_ref, 0.0001, 0.0001));
    }
}
|
|
|
|
// Sweeps a grid of sequence lengths and head dimensions and checks that GetSolutions("gfx90a")
// returns at least one solution for every fp16, non-causal, bias-free combination.
// No kernel is compiled or launched here.
TEST_CASE(sweep_fmha_fwd_solutions)
{
    const std::vector<std::size_t> q_lengths{512, 1024, 2048, 4096};
    const std::vector<std::size_t> k_lengths{512, 1024, 2048, 4096};
    const std::vector<std::size_t> q_head_dims{32, 48, 64, 80, 96, 128, 192, 256};
    const std::vector<std::size_t> v_head_dims{32, 48, 64, 80, 96, 128, 192, 256};

    constexpr int fixed_batch = 2;
    constexpr int fixed_heads = 4;

    for(std::size_t seqlen_q : q_lengths)
    {
        for(std::size_t seqlen_k : k_lengths)
        {
            for(std::size_t hdim_q : q_head_dims)
            {
                for(std::size_t hdim_v : v_head_dims)
                {
                    ck::host::device_fmha_fwd::Problem prob;
                    prob.batch         = fixed_batch;
                    prob.nhead         = fixed_heads;
                    prob.M             = seqlen_q;
                    prob.N             = seqlen_k;
                    prob.K             = hdim_q;
                    prob.O             = hdim_v;
                    prob.dtype         = ck::host::DataType::Half;
                    prob.is_v_rowmajor = true;
                    prob.is_causal     = false;
                    prob.has_bias      = false;

                    auto solutions = prob.GetSolutions("gfx90a");

                    // Name the failing configuration before the check below records it.
                    const bool have_any = !solutions.empty();
                    if(!have_any)
                    {
                        std::cout << "Config M=" << seqlen_q << ", N=" << seqlen_k
                                  << ", K=" << hdim_q << ", O=" << hdim_v
                                  << ": No solutions available" << std::endl;
                    }
                    CHECK(!solutions.empty());
                }
            }
        }
    }
}
|
|
|
|
// Program entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
|