now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -1,6 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include <array>
#include <cstring>
#include <functional>
@@ -9,11 +14,24 @@
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
// Writes a vector to the stream as a bracketed, comma-separated list,
// e.g. "[1, 2, 3]"; an empty vector is written as "[]".
// Each element is written with its own operator<<. Returns the same stream.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    bool is_first = true;
    for(const auto& element : v)
    {
        // the separator goes between elements only, never before the first one
        if(!is_first)
        {
            os << ", ";
        }
        is_first = false;
        os << element;
    }
    os << "]";
    return os;
}
auto create_args(int argc, char* argv[])
{
@@ -91,12 +109,12 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k == 0)
nhead_k = nhead;
@@ -143,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
stream_config stream_config{
ck_tile::stream_config stream_config{
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
@@ -207,53 +225,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
HostTensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
HostTensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
HostTensor<VDataType> v_host(
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
HostTensor<BiasDataType> bias_host(
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<BiasDataType> bias_host(
use_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
HostTensor<LSEDataType> lse_host(
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
HostTensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
if(init_method == 0)
{
ck_tile::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
}
else if(init_method == 1)
{
ck_tile::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
else if(init_method == 2)
{
ck_tile::utils::FillTrigValue<QDataType>{}(q_host);
ck_tile::utils::FillTrigValue<KDataType>{}(k_host);
ck_tile::utils::FillTrigValue<VDataType>{}(v_host);
ck_tile::utils::FillTrigValue<BiasDataType>{}(bias_host);
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -349,19 +371,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
const auto v_host_ref_lengths = std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_lengths =
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_strides =
is_v_rowmajor ? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
is_v_rowmajor
? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
ck_tile::index_t nr = nhead / nhead_k;
@@ -386,7 +410,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
// reference
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
@@ -396,7 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(use_bias)
{
HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
@@ -406,43 +430,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
// real_seqlen_k]
reference_batched_elementwise<SMPLComputeDataType,
BiasDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
BiasDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
}
else if(mask.type == mask_enum::window_generic)
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
}
else
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
}
if(lse)
{
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref, lse_host_ref);
}
else
{
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref);
}
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host_ref, o_host_ref);
HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
// permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
@@ -450,7 +474,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = ck_tile::utils::check_err(
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass;
if(!cur_pass)
@@ -466,17 +490,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse)
{
HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
});
bool lse_pass = ck_tile::utils::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
bool lse_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= lse_pass;
if(!cur_pass)

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include <type_traits>
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -19,11 +20,11 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
};
@@ -34,11 +35,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
@@ -48,12 +49,12 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float; // TODO: fix me
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using BiasDataType = float; // TODO: fix me
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
};
@@ -64,11 +65,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf8_t>
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
};
@@ -107,7 +108,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
ck_tile::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -298,26 +299,26 @@ template <ck_tile::index_t HDim_,
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_fwd_(const stream_config&, fmha_fwd_args);
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
@@ -332,4 +333,4 @@ struct fmha_fwd_traits
bool has_lse;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const stream_config&);
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

View File

@@ -103,12 +103,12 @@ using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::FmhaFwdEpilogue<FmhaFwdEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
@@ -117,7 +117,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -131,7 +131,7 @@ float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const stream_config& s){{
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;

View File

@@ -51,7 +51,7 @@ struct mask_info
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation

21
example/ck_tile/remod.py Normal file
View File

@@ -0,0 +1,21 @@
import pathlib
from pathlib import Path
import subprocess
import os
import copy
# Collect every C++ header/source file below the current working directory,
# in sorted order.
all_files = [
    pathlib.PurePath(p)
    for p in sorted(Path("./").rglob("*"))
    if p.suffix in ['.hpp', '.cpp']
]

# formatting
# For each file: convert line endings with dos2unix, then run clang-format in
# place. Each tool is run to completion before the next one starts. The earlier
# Popen calls never waited, so both tools could rewrite the same file at once
# and an unbounded number of processes was spawned.
for x in all_files:
    # Argument lists (no shell) keep paths with spaces or shell metacharacters
    # intact. Non-zero exit codes are still ignored (check=False); a missing
    # executable now raises FileNotFoundError instead of a shell error message.
    subprocess.run(['dos2unix', str(x)], check=False)
    subprocess.run(['clang-format-12', '-style=file', '-i', str(x)], check=False)