now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -95,6 +95,7 @@ else()
-Wno-weak-vtables
-Wno-covered-switch-default
-Wno-unsafe-buffer-usage
-Wno-unused-lambda-capture
)
else()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")

View File

@@ -1,6 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include <array>
#include <cstring>
#include <functional>
@@ -9,11 +14,24 @@
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
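This new helper streams a vector as a bracketed, comma-separated list, used below when printing seqstart vectors. A standalone copy of the same pattern (the values in main are made up):

#include <iostream>
#include <vector>

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    for(typename std::vector<T>::size_type idx = 0; idx < v.size(); ++idx)
    {
        if(0 < idx)
            os << ", ";
        os << v[idx];
    }
    return os << "]";
}

int main()
{
    std::vector<int> seqstart{0, 128, 256};
    std::cout << seqstart << '\n'; // prints: [0, 128, 256]
}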
auto create_args(int argc, char* argv[])
{
@@ -91,12 +109,12 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k == 0)
nhead_k = nhead;
@@ -143,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
stream_config stream_config{
ck_tile::stream_config stream_config{
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
@@ -207,53 +225,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
HostTensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
HostTensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
HostTensor<VDataType> v_host(
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
HostTensor<BiasDataType> bias_host(
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<BiasDataType> bias_host(
use_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
// self-defined lse data layout: [shape_batch, nhead, shape_seqlen_q]
HostTensor<LSEDataType> lse_host(
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
HostTensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
if(init_method == 0)
{
ck_tile::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
}
else if(init_method == 1)
{
ck_tile::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
else if(init_method == 2)
{
ck_tile::utils::FillTrigValue<QDataType>{}(q_host);
ck_tile::utils::FillTrigValue<KDataType>{}(k_host);
ck_tile::utils::FillTrigValue<VDataType>{}(v_host);
ck_tile::utils::FillTrigValue<BiasDataType>{}(bias_host);
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
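Each DeviceMem above is sized in bytes from its host tensor and then filled with ToDevice. A plain-C++ sketch of that staging shape, with malloc/memcpy standing in for the device allocation and copy that ck_tile::DeviceMem presumably wraps (an assumption about its internals):

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <vector>

// minimal stand-in: size the buffer from the host container's byte count,
// then copy the host data in
struct fake_device_mem
{
    explicit fake_device_mem(std::size_t bytes) : ptr(std::malloc(bytes)), size(bytes) {}
    ~fake_device_mem() { std::free(ptr); }
    void to_device(const void* host) { std::memcpy(ptr, host, size); }
    void* ptr;
    std::size_t size;
};

int main()
{
    std::vector<float> q_host(16, 1.0f);
    fake_device_mem q_buf(q_host.size() * sizeof(float));
    q_buf.to_device(q_host.data());
}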
@@ -349,19 +371,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
const auto v_host_ref_lengths = std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_lengths =
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_strides =
is_v_rowmajor ? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
is_v_rowmajor
? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
ck_tile::index_t nr = nhead / nhead_k;
@@ -386,7 +410,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
// reference
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
@@ -396,7 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(use_bias)
{
HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
@@ -406,43 +430,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
// real_seqlen_k]
reference_batched_elementwise<SMPLComputeDataType,
BiasDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
BiasDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
}
else if(mask.type == mask_enum::window_generic)
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
}
else
{
reference_batched_masking<SaccDataType>(
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
}
if(lse)
{
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref, lse_host_ref);
}
else
{
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref);
}
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host_ref, o_host_ref);
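The verification path above chains Q*K^T, optional bias add, masking, softmax, then P*V. The same reference math restated as a self-contained single-head example (no bias or mask; sizes and fill values are made up):

#include <cmath>
#include <cstdio>
#include <vector>

// naive O = softmax(scale * Q K^T) V for one head;
// Q is M x K, K is N x K, V is N x Dv, O is M x Dv, all row-major
int main()
{
    const int M = 2, N = 3, K = 4, Dv = 2;
    std::vector<float> q(M * K, 0.1f), k(N * K, 0.2f), v(N * Dv, 0.3f), o(M * Dv);
    const float scale = 1.0f / std::sqrt(static_cast<float>(K));
    for(int m = 0; m < M; ++m)
    {
        std::vector<float> s(N);
        float row_max = -INFINITY, sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int kk = 0; kk < K; ++kk)
                acc += q[m * K + kk] * k[n * K + kk]; // first gemm: Q K^T
            s[n] = acc * scale;
            row_max = std::fmax(row_max, s[n]);
        }
        for(int n = 0; n < N; ++n)
        {
            s[n] = std::exp(s[n] - row_max); // numerically stable softmax
            sum += s[n];
        }
        for(int d = 0; d < Dv; ++d)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (s[n] / sum) * v[n * Dv + d]; // second gemm: P V
            o[m * Dv + d] = acc;
        }
    }
    std::printf("o[0][0] = %f\n", o[0]);
}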
HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
// permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
@@ -450,7 +474,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = ck_tile::utils::check_err(
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass;
if(!cur_pass)
@@ -466,17 +490,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse)
{
HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
});
bool lse_pass = ck_tile::utils::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
bool lse_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= lse_pass;
if(!cur_pass)

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include <type_traits>
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -19,11 +20,11 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
};
@@ -34,11 +35,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
@@ -48,12 +49,12 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float; // TODO: fix me
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using BiasDataType = float; // TODO: fix me
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
};
@@ -64,11 +65,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf8_t>
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
};
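FmhaFwdTypeConfig follows the usual type-config idiom: a primary template left undefined plus one explicit specialization per element type, so every derived type (accumulators, LSE, output) hangs off a single parameter. A standalone illustration of the idiom with hypothetical stand-in names:

#include <type_traits>

template <typename DataType>
struct GemmTypeConfig; // primary template intentionally undefined

template <>
struct GemmTypeConfig<float>
{
    using AccDataType = double; // hypothetical accumulator choice
    using OutDataType = float;
};

// a kernel pulls its working types from the config
template <typename DataType>
using acc_t = typename GemmTypeConfig<DataType>::AccDataType;

static_assert(std::is_same_v<acc_t<float>, double>);

int main() {}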
@@ -107,7 +108,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
ck_tile::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -298,26 +299,26 @@ template <ck_tile::index_t HDim_,
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_fwd_(const stream_config&, fmha_fwd_args);
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
@@ -332,4 +333,4 @@ struct fmha_fwd_traits
bool has_lse;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const stream_config&);
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
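fmha_fwd is the generated public entry point: it inspects the runtime traits and forwards to the matching fmha_fwd_<trait_...> instantiation. A heavily reduced sketch of that dispatch shape, with a hypothetical bool knob standing in for the generated {F_dispatch} chain:

#include <cstdio>

struct toy_args { float x; };
struct toy_traits { bool is_group_mode; };

// each instantiation stands in for one generated fmha_fwd_<trait_N>
template <bool kIsGroupMode>
float toy_fwd_(const toy_args& a)
{
    return kIsGroupMode ? a.x * 2.f : a.x;
}

// the public entry point maps runtime traits onto compile-time template args
float toy_fwd(toy_traits t, toy_args a)
{
    if(t.is_group_mode)
        return toy_fwd_<true>(a);
    return toy_fwd_<false>(a);
}

int main() { std::printf("%f\n", toy_fwd(toy_traits{true}, toy_args{1.5f})); }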

View File

@@ -103,12 +103,12 @@ using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::FmhaFwdEpilogue<FmhaFwdEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
@@ -117,7 +117,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -131,7 +131,7 @@ float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const stream_config& s){{
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;

View File

@@ -51,7 +51,7 @@ struct mask_info
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation

example/ck_tile/remod.py Normal file
View File

@@ -0,0 +1,21 @@
from pathlib import Path
import subprocess

# collect every C++ source/header under the current directory
all_files = [p for p in sorted(Path(".").rglob("*")) if p.suffix in ['.hpp', '.cpp']]

# normalize line endings first, then clang-format in place;
# subprocess.run (not Popen) so dos2unix finishes before clang-format
# touches the same file
for x in all_files:
    subprocess.run(f'dos2unix {x}', shell=True)
    subprocess.run(f'clang-format-12 -style=file -i {x}', shell=True)

View File

@@ -336,8 +336,8 @@ struct buffer_store<2>
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
@@ -468,9 +468,9 @@ struct buffer_store_if<2>
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
@@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
}
// buffer load i8
CK_TILE_DEVICE int8_t
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE int8x2_t
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE int8x4_t
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE int16_t
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE int16x2_t
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE int16x4_t
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE int32_t
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE int32x2_t
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE int32x4_t
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE fp16_t
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE fp16x2_t
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE fp16x4_t
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE float
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE fp32x2_t
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE fp32x4_t
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE void
llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata,
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
@@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
@@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE double
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
@@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<fp16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<bf16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<bf16x2_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<bf16x4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
@@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<0>{}],
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<bf16x4_t>()[number<1>{}],
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bf16_t),
@@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
}
// Direct loads from global to LDS.
CK_TILE_DEVICE void
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,

View File

@@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
CK_TILE_DEVICE void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck_tile

View File

@@ -9,6 +9,9 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {
@@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_up(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_down(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
} // namespace ck_tile
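The bpermute-based shuffles compute a byte address per lane (hence the << 2): shuffle-up by delta makes lane i read lane (i - delta) mod warpSize, and shuffle-down reads (i + delta) mod warpSize. Note the wrap-around; __shfl_up instead leaves the lowest delta lanes unchanged. A host-side emulation of the index math, assuming warpSize == 64 as on gfx9:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    constexpr uint32_t warp_size = 64, delta = 4;
    std::vector<uint32_t> lanes(warp_size);
    for(uint32_t i = 0; i < warp_size; ++i)
        lanes[i] = i * 10; // per-lane payload
    for(uint32_t i = 0; i < 4; ++i)
    {
        // matches (__lane_id() + (warp_size - delta)) in warp_shuffle_up
        uint32_t up_src = (i + (warp_size - delta)) % warp_size;
        // matches (__lane_id() + delta) in warp_shuffle_down
        uint32_t down_src = (i + delta) % warp_size;
        std::printf("lane %u: up <- %u, down <- %u\n", i, lanes[up_src], lanes[down_src]);
    }
}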

View File

@@ -9,13 +9,15 @@
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST __host__
#define CK_TILE_DEVICE __device__
#define CK_TILE_HOST_DEVICE __host__ __device__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
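The inline added to these macros matters because the functions they mark are defined in headers pulled into many translation units; without inline, each inclusion is a separate definition and linking fails on the one-definition rule. The __asm intrinsic prototypes, by contrast, are declarations with no body, so they get the new CK_TILE_DEVICE_EXTERN spelling without inline. A minimal header-style illustration:

// in a header included by many .cpp files:
inline float twice(float x) { return 2.0f * x; } // definition: needs inline

float backend_intrinsic(float x); // declaration only: body supplied elsewhere,
                                  // so inline would be wrong here

int main() { return twice(2.0f) == 4.0f ? 0 : 1; }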
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
@@ -122,7 +124,7 @@
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
@@ -132,3 +134,7 @@
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif

View File

@@ -21,7 +21,12 @@ struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// TODO: will initialize the data[] with the last value repeatedly
// behavior different from std
@@ -44,18 +49,24 @@ struct array
data[i] = vlast;
}
}
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
template <typename Y>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = c;
}
template <typename ArrayType>
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
{
static_assert(ArrayType::size() == size(), "wrong! size not the same");
for(auto i = 0; i < size(); i++)
data[i] = o.data[i];
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
@@ -147,10 +158,10 @@ struct vector_traits<array<T, N>>
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
{
using value_type = remove_cvref_t<T>;
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
}
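make_array now takes the element type explicitly instead of deducing it from a first argument, so all arguments are converted uniformly. A standalone analogue over std::array showing the revised shape:

#include <array>
#include <utility>

// the caller names the element type; the pack length fixes N
template <typename T, typename... Ts>
constexpr std::array<T, sizeof...(Ts)> make_array_like(Ts&&... xs)
{
    return {static_cast<T>(std::forward<Ts>(xs))...};
}

static_assert(make_array_like<int>(1, 2, 3).size() == 3);
static_assert(make_array_like<float>(1, 2.0, 3.f)[1] == 2.0f); // mixed inputs converge on T

int main() {}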
// make empty array

View File

@@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
@@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
}()
#endif
} // namespace ck_tile

View File

@@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
for(index_t i = 0; i < Seq::size(); ++i)
{
result = f(result, Seq::get(i));
result = f(result, Seq::at(i));
}
return result;
@@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag || f(Seq::get(i));
flag = flag || f(Seq::at(i));
}
return flag;
@@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag && f(Seq::get(i));
flag = flag && f(Seq::at(i));
}
return flag;
@@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// template <index_t... Is>
// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple<number<Is>...>)
// {
// return sequence<Is...>{};
// }
template <class... T>
struct tuple;
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
return sequence<Is...>{};
}
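to_sequence round-trips a tuple of number<> constants back into a sequence. The same trick expressed with the standard-library equivalents, std::integral_constant and std::integer_sequence:

#include <tuple>
#include <type_traits>
#include <utility>

template <int... Is>
constexpr auto to_sequence_like(std::tuple<std::integral_constant<int, Is>...>)
{
    return std::integer_sequence<int, Is...>{};
}

using seq = decltype(to_sequence_like(
    std::tuple<std::integral_constant<int, 0>, std::integral_constant<int, 4>>{}));
static_assert(std::is_same_v<seq, std::integer_sequence<int, 0, 4>>);

int main() {}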
namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>

View File

@@ -139,6 +139,26 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
return !(a == b);
}
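The new operator== walks both tuples with static_for and ANDs the element comparisons. An equivalent standalone version over std::tuple, using a fold over an index pack in place of static_for:

#include <cstddef>
#include <tuple>
#include <utility>

template <typename... Xs, std::size_t... Is>
constexpr bool tuple_eq_impl(const std::tuple<Xs...>& a,
                             const std::tuple<Xs...>& b,
                             std::index_sequence<Is...>)
{
    return ((std::get<Is>(a) == std::get<Is>(b)) && ...);
}

template <typename... Xs>
constexpr bool tuple_eq(const std::tuple<Xs...>& a, const std::tuple<Xs...>& b)
{
    return tuple_eq_impl(a, b, std::index_sequence_for<Xs...>{});
}

static_assert(tuple_eq(std::make_tuple(1, 2), std::make_tuple(1, 2)));
static_assert(!tuple_eq(std::make_tuple(1, 2), std::make_tuple(1, 3)));

int main() {}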
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
@@ -237,21 +257,21 @@ template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten
@@ -490,58 +510,58 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) \
_Pragma("clang diagnostic pop")
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif

View File

@@ -4,44 +4,36 @@
#pragma once
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
CK_TILE_HOST_DEVICE \
bool operator==(const type_& x, const type_& y) \
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator!=(const type_& x, const type_& y) \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<(const type_& x, const type_& y) \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator<=(const type_& x, const type_& y) \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>(const type_& x, const type_& y) \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
bool operator>=(const type_& x, const type_& y) \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
CK_TILE_HOST_DEVICE \
type_ operator+(const type_& x, const type_& y) \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x) \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
@@ -49,66 +41,55 @@
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator-(const type_& x, const type_& y) \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator*(const type_& x, const type_& y) \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_ operator/(const type_& x, const type_& y) \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
CK_TILE_HOST_DEVICE \
type_& operator+=(type_& x, const type_& y) \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator-=(type_& x, const type_& y) \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator*=(type_& x, const type_& y) \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator/=(type_& x, const type_& y) \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator++(type_& x) \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_& operator--(type_& x) \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
CK_TILE_HOST_DEVICE \
type_ operator++(type_& x, int) \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
CK_TILE_HOST_DEVICE \
type_ operator--(type_& x, int) \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \

View File

@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x);
// HIP use __hip_bfloat16 as struct
struct alignas(2) bfloat16_t
{
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
return float_to_bf16_truc_raw(f);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
{
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
return u.fp32;
}
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
{
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
}
};
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
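The new double overloads simply narrow through float, which loses nothing relative to bf16 itself: a bf16 value is the high 16 bits of an IEEE-754 float. A standalone truncation round trip:

#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bf16_trunc(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16); // keep sign + exponent + top 7 mantissa bits
}

float bf16_to_float_bits(uint16_t x)
{
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    double d = 3.140625; // exactly representable in bf16
    std::printf("%g -> %g\n", d, bf16_to_float_bits(float_to_bf16_trunc(static_cast<float>(d))));
}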
// math
CK_TILE_HOST_DEVICE

View File

@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
// convert to bitwise
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
// check if x is 0.0
if(x_bitwise == 0)
return 0;
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
}
else
{
return signed_inf;
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
return __builtin_bit_cast(
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
}
template <typename X, typename Y, bool negative_zero_nan>
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr X nan_code = 0x80;
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x == 0)
if(x_raw == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> in_mant;
uint32_t sign = x_raw >> (in_exp + in_mant);
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
int exponent = (x_raw & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if constexpr(negative_zero_nan)
{
if(x == nan_code)
if(x_raw == nan_code)
return NaN;
}
else
{
if(x == nan_code)
if(x_raw == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval = x_raw;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
};
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
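The casts in run_cast_to_f8 / run_cast_from_f8 changed because the functions traffic in a typed fp8 wrapper, not a plain integer, so raw byte patterns must go through __builtin_bit_cast instead of an implicit integral conversion. A standalone illustration with a stand-in 1-byte type:

#include <cstdint>
#include <cstdio>

struct f8_stub
{
    uint8_t data;
};
static_assert(sizeof(f8_stub) == 1, "bit_cast requires matching sizes");

int main()
{
    // reinterpret the nan code's byte pattern as the wrapper type
    constexpr f8_stub nan_code = __builtin_bit_cast(f8_stub, static_cast<uint8_t>(0x80));
    std::printf("0x%02x\n", static_cast<unsigned>(nan_code.data)); // prints 0x80
}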
// math
CK_TILE_HOST_DEVICE

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include <hip/hip_fp16.h>
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x);
// HIP use fp16_hip_t as interchangable data type for float16
struct alignas(2) half_t
{
@@ -46,6 +53,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
@@ -61,6 +72,10 @@ struct alignas(2) half_t
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x)
{
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
double fp16_to_double(const half_t& x) { return static_cast<double>(x); }
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
CK_TILE_HOST_DEVICE
half_t double_to_fp16(const double& x) { return half_t{x}; }
// limits
template <class T>
struct numeric_limits;
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
};
// arithmetic
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
return y;
}
CK_TILE_HOST_DEVICE
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
return y;
}
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
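
A minimal host-side sketch of the double <-> fp16 round trip added above, written against plain HIP __half rather than ck_tile::half_t (which wraps the same float-mediated conversions); the value is only illustrative:

#include <hip/hip_fp16.h>
#include <iostream>

int main()
{
    double d = 3.14159;
    // double -> fp16 goes through float, mirroring double_to_fp16_hip above
    __half h = static_cast<__half>(static_cast<float>(d));
    // fp16 -> double likewise widens through float (fp16_to_double_hip)
    double out = static_cast<double>(static_cast<float>(h));
    std::cout << out << '\n'; // ~3.140625: fp16 keeps about 3 decimal digits
    return 0;
}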

View File

@@ -14,8 +14,9 @@ struct constant
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <typename T, T v>

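As a usage note for the reordered qualifiers above: constant is an integral_constant-style wrapper usable both as a value and as a callable. A standalone sketch of the same pattern (C++17 auto template parameter assumed; the CK_TILE_HOST_DEVICE attributes are omitted here):

template <auto v>
struct constant
{
    using value_type = decltype(v);
    static constexpr value_type value = v;
    constexpr operator value_type() const noexcept { return value; }
    constexpr value_type operator()() const noexcept { return value; }
    static constexpr bool is_static() { return true; }
};

static_assert(constant<4>{} + constant<5>{}() == 9); // value via conversion or via call
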
View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace ck_tile {
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x needs to be in 1 ~ 0x7fffffff
// __builtin_clz will produce an unexpected result if x is 0;
return 31 - clz(x);
return 31 - __builtin_clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
@@ -275,7 +276,7 @@ struct log2e<float>
};
template <typename T = double>
inline constexpr T log2e_v = log2e<T>::value;
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
@@ -298,16 +299,32 @@ bool isnan(const float& x)
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
} // namespace ck_tile
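
The hunks above split the math routines into separate CK_TILE_HOST and CK_TILE_DEVICE definitions, relying on clang's HIP host/device overloading. A minimal sketch of the pattern, assuming compilation with hipcc:

#include <cmath>
#include <hip/hip_runtime.h>

// Same signature, different bodies selected by execution space (clang HIP
// allows overloading on __host__/__device__, which is what the
// CK_TILE_HOST/CK_TILE_DEVICE macros expand to).
__host__ inline float fast_exp(float x) { return std::expf(x); } // libm on host
__device__ inline float fast_exp(float x) { return __expf(x); }  // fast intrinsic on device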

View File

@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)
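
For reference, CK_TILE_TYPE_CONVERT token-pastes the conversion-function name from its arguments, so the instantiation above expands (roughly, modulo the CK_TILE_HOST_DEVICE attribute) to:

// stype_##_to_##dtype_ with (dtype_ = float, stype_ = fp16_t)
template <>
CK_TILE_HOST_DEVICE constexpr float type_convert<float, fp16_t>(fp16_t x)
{
    return fp16_t_to_float(x);
}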

View File

@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bfp16
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));

View File

@@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
using InVecType = typename InVec::type;
using OutVecType = typename OutVec::type;
// using InVec = typename InVec::type;
// using OutVec = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
@@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVecType>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVec>(
number<in_offset / vec_length_in>{});
});
// transpose
@@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVecType>(
number<out_offset / sizeof(OutVecType)>{},
out_vectors[i].template get_as<OutVecType>()[I0]);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
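
The fix above switches the thread-buffer accessors from a byte-derived index (out_offset / sizeof(OutVecType)) to a consistent element-count index (offset / vec_length). A standalone sketch of that arithmetic, under the same vector-alignment assumption the static_asserts check:

#include <cassert>
#include <cstddef>

// A flat element buffer viewed as packed vectors of VecLen elements:
// the vector index is the element offset divided by the vector length.
template <std::size_t VecLen>
constexpr std::size_t vector_index(std::size_t element_offset)
{
    assert(element_offset % VecLen == 0); // mirrors the static_asserts above
    return element_offset / VecLen;
}

static_assert(vector_index<4>(8) == 2);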

View File

@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \

View File

@@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
{
return Base::get_top_dimension_length();
return Base::get_top_dimension_lengths();
}
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const

View File

@@ -296,7 +296,8 @@ CK_TILE_HOST_DEVICE constexpr auto
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
// typename HsLengthss::base{}.foo();
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
constexpr auto h_minor_lengths =
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();
@@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
@@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
@@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
@@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"

View File

@@ -7,14 +7,14 @@
#if 1
// clang happens to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else

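TO_SEQUENCE above relies on a C++20 templated lambda (which clang accepts in C++17 mode behind the diagnostic pragmas) to unpack a constexpr array into a sequence<...>. A standalone C++20 sketch of the same trick, with a local sequence type standing in for ck_tile::sequence:

#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

template <int... Is>
struct sequence {};

inline constexpr std::array<int, 3> arr{2, 4, 8};

// The templated lambda unpacks the array elements into a template parameter pack.
inline constexpr auto seq = []<std::size_t... I>(std::index_sequence<I...>) {
    return sequence<arr[I]...>{};
}(std::make_index_sequence<arr.size()>{});

static_assert(std::is_same_v<decltype(seq), const sequence<2, 4, 8>>);
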
View File

@@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
namespace impl {
template <typename T>
struct is_static_impl
{
static constexpr bool value = std::is_arithmetic<T>::value ? false : T::is_static();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
@@ -69,6 +48,36 @@ struct nonesuch
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
namespace impl {
template <typename T>
using has_is_static = decltype(T::is_static());
template <typename T>
struct is_static_impl
{
static constexpr bool value = []() {
if constexpr(is_detected<has_is_static, T>{})
return T::is_static();
else
return std::is_arithmetic<T>::value;
}();
};
} // namespace impl
template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;
// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;
// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
template <
typename PY,

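The rewrite above replaces the unconditional T::is_static() call with the detection idiom, so arithmetic types (which have no is_static member) fall back to std::is_arithmetic. A standalone sketch of the same fallback:

#include <type_traits>

// Call T::is_static() only when it exists; otherwise treat arithmetic
// types as compile-time known, as the new is_static_impl above does.
template <typename T, typename = void>
struct has_is_static : std::false_type {};

template <typename T>
struct has_is_static<T, std::void_t<decltype(T::is_static())>> : std::true_type {};

template <typename T>
inline constexpr bool is_static_v = []() {
    if constexpr (has_is_static<T>::value)
        return T::is_static();
    else
        return std::is_arithmetic<T>::value;
}();

struct S { static constexpr bool is_static() { return true; } };
static_assert(is_static_v<S> && is_static_v<int> && !is_static_v<void*>);
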
View File

@@ -40,7 +40,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -98,7 +98,7 @@ template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -157,7 +157,7 @@ template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
bool>::type CK_TILE_HOST
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
@@ -182,7 +182,7 @@ check_err(const Range& out,
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
@@ -220,11 +220,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
@@ -270,12 +270,12 @@ template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
@@ -323,12 +323,12 @@ template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{

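A minimal host-side call sketch for the check_err overloads above, assuming check_err lives in the ck_tile namespace like the other host utilities, and assuming the umbrella host header as the include; tolerances are passed explicitly as (rtol, atol):

#include "ck_tile/host.hpp" // assumption: umbrella host header
#include <vector>

int main()
{
    std::vector<float> out{1.0f, 2.0f, 3.0f};
    std::vector<float> ref{1.0f, 2.0f, 3.0000001f};
    // floating-point overload declared above
    bool ok = ck_tile::check_err(out, ref, "Error: Incorrect results!", 1e-5, 1e-5);
    return ok ? 0 : 1;
}
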
View File

@@ -3,13 +3,14 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <sstream>
#include <stdexcept>
#include <hip/hip_runtime.h>
namespace ck_tile {
// To be removed: this does not report the location of the failed HIP function call
inline void hip_check_error(hipError_t x)
CK_TILE_HOST void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
{

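Usage sketch for hip_check_error above: wrap any HIP runtime call that returns hipError_t, so a failure is reported (the <stdexcept> include added above suggests it throws):

#include <hip/hip_runtime.h>

int main()
{
    float* buf = nullptr;
    ck_tile::hip_check_error(hipMalloc(&buf, 256 * sizeof(float)));
    ck_tile::hip_check_error(hipFree(buf));
    return 0;
}
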
View File

@@ -18,11 +18,11 @@
namespace ck_tile {
template <typename Range>
std::ostream& LogRange(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
@@ -37,11 +37,11 @@ std::ostream& LogRange(std::ostream& os,
}
template <typename T, typename Range>
std::ostream& LogRangeAsType(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
@@ -56,13 +56,13 @@ std::ostream& LogRangeAsType(std::ostream& os,
}
template <typename F, typename T, std::size_t... Is>
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
return f(std::get<Is>(args)...);
}
template <typename F, typename T>
auto call_f_unpack_args(F f, T args)
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
{
constexpr std::size_t N = std::tuple_size<T>{};
@@ -70,13 +70,13 @@ auto call_f_unpack_args(F f, T args)
}
template <typename F, typename T, std::size_t... Is>
auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
return F(std::get<Is>(args)...);
}
template <typename F, typename T>
auto construct_f_unpack_args(F, T args)
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
{
constexpr std::size_t N = std::tuple_size<T>{};
@@ -87,7 +87,19 @@ struct HostTensorDescriptor
{
HostTensorDescriptor() = default;
void CalculateStrides();
void CalculateStrides()
{
mStrides.clear();
mStrides.resize(mLens.size(), 0);
if(mStrides.empty())
return;
mStrides.back() = 1;
std::partial_sum(mLens.rbegin(),
mLens.rend() - 1,
mStrides.rbegin() + 1,
std::multiplies<std::size_t>());
}
template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
@@ -123,12 +135,28 @@ struct HostTensorDescriptor
{
}
std::size_t get_num_of_dimension() const;
std::size_t get_element_size() const;
std::size_t get_element_space_size() const;
std::size_t get_num_of_dimension() const { return mLens.size(); }
std::size_t get_element_size() const
{
assert(mLens.size() == mStrides.size());
return std::accumulate(
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t get_element_space_size() const
{
std::size_t space = 1;
for(std::size_t i = 0; i < mLens.size(); ++i)
{
if(mLens[i] == 0)
continue;
const std::vector<std::size_t>& get_lengths() const;
const std::vector<std::size_t>& GetStrides() const;
space += (mLens[i] - 1) * mStrides[i];
}
return space;
}
const std::vector<std::size_t>& get_lengths() const { return mLens; }
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
@@ -151,8 +179,8 @@ struct HostTensorDescriptor
};
template <typename New2Old>
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
const New2Old& new2old)
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(
const HostTensorDescriptor& a, const New2Old& new2old)
{
std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
std::vector<std::size_t> new_strides(a.get_num_of_dimension());
@@ -238,7 +266,7 @@ struct ParallelTensorFunctor
};
template <typename F, typename... Xs>
auto make_ParallelTensorFunctor(F f, Xs... xs)
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
{
return ParallelTensorFunctor<F, Xs...>(f, xs...);
}
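
CalculateStrides (now defined inline above) builds packed row-major strides as a reverse partial product. A standalone sketch of the same computation with a worked example:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Packed row-major strides, innermost dimension stride 1, computed by the
// same reverse partial_sum as CalculateStrides above.
std::vector<std::size_t> make_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 0);
    if(strides.empty())
        return strides;
    strides.back() = 1;
    std::partial_sum(lens.rbegin(), lens.rend() - 1, strides.rbegin() + 1,
                     std::multiplies<std::size_t>());
    return strides; // lens {2, 3, 4} -> strides {12, 4, 1}
}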

View File

@@ -20,12 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
}
template <typename... Args, typename F>
float launch_and_time_kernel(const stream_config& s,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
@@ -75,13 +75,13 @@ float launch_and_time_kernel(const stream_config& s,
}
template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const stream_config& s,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
@@ -151,12 +151,12 @@ template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
CK_TILE_HOST float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
{
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;

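launch_and_time_kernel above follows the standard HIP event-timing pattern. A standalone sketch of that pattern (not the CK implementation; warmup/repeat mirror the stream_config fields, and error handling is elided for brevity):

#include <hip/hip_runtime.h>

template <typename Launch>
float time_ms(hipStream_t stream, int warmup, int repeat, Launch launch)
{
    for(int i = 0; i < warmup; ++i)
        launch();

    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);
    (void)hipEventRecord(start, stream);
    for(int i = 0; i < repeat; ++i)
        launch();
    (void)hipEventRecord(stop, stream);
    (void)hipEventSynchronize(stop);

    float ms = 0.f;
    (void)hipEventElapsedTime(&ms, start, stop);
    (void)hipEventDestroy(start);
    (void)hipEventDestroy(stop);
    return ms / repeat; // average time per launch
}
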
View File

@@ -10,7 +10,6 @@
// the ranges implementation is not intended to be used by users
// TODO: do we need this?
namespace ck_tile {
namespace ranges {
template <typename T>
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;
@@ -21,8 +20,7 @@ using iter_reference_t = decltype(*std::declval<T&>());
template <typename T>
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;
//.........................
namespace ranges {
template <typename R>
using iterator_t = decltype(std::begin(std::declval<R&>()));

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename BinaryElementOp = ck_tile::plus<AccDataType>>
void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
{
const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename CDataType, typename MaskingType>
void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
const int M = c_b_m_n.mDesc.get_lengths()[1];
const int N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename CompDataType, typename BDataType>
void reference_batched_softmax(
CK_TILE_HOST void reference_batched_softmax(
const HostTensor<ADataType>& a_b_m_n,
HostTensor<BDataType>& b_b_m_n,
std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];

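As the b_n_k reads above show, reference_gemm treats A as [M, K] and B as [N, K] (both K-contiguous). A standalone sketch of the equivalent loop nest, with the element ops omitted:

#include <vector>

// A is MxK row-major, B is NxK row-major (i.e. B^T of the usual KxN), C is MxN.
template <typename T, typename Acc = float>
void naive_gemm(const std::vector<T>& a, const std::vector<T>& b, std::vector<T>& c,
                int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            Acc acc = 0;
            for(int k = 0; k < K; ++k)
                acc += static_cast<Acc>(a[m * K + k]) * static_cast<Acc>(b[n * K + k]);
            c[m * N + n] = static_cast<T>(acc);
        }
}
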
View File

@@ -10,25 +10,25 @@
namespace ck_tile {
template <typename T>
void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
{
int GemmM = in_mtx_host_ref.get_lengths()[0];
int GemmK = in_mtx_host_ref.get_lengths()[1];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];

View File

@@ -10,12 +10,13 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m_n)
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
AccDataType v_max = ck_tile::NumericLimits<ADataType>::Lowest();
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)

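reference_softmax seeds the running maximum with numeric_limits<ADataType>::Lowest() (the rename applied above) and computes the usual max-subtracted softmax. A standalone single-row sketch with AccDataType = float:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> softmax_row(const std::vector<float>& a)
{
    float v_max = std::numeric_limits<float>::lowest();
    for(float x : a)
        v_max = std::max(v_max, x); // running row max

    float v_sum = 0.f;
    std::vector<float> b(a.size());
    for(std::size_t n = 0; n < a.size(); ++n)
    {
        b[n] = std::exp(a[n] - v_max); // subtract the max for numerical stability
        v_sum += b[n];
    }
    for(float& y : b)
        y /= v_sum;
    return b;
}
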
View File

@@ -575,9 +575,8 @@ struct FmhaFwdKernel
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// following copy capture of the 'i_nhead'
/// if compiled in C++20
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// the following copy capture of 'i_nhead' if compiling in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -189,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -208,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -346,12 +347,15 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -360,7 +364,7 @@ struct BlockFmhaPipelineQRKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -375,7 +379,7 @@ struct BlockFmhaPipelineQRKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
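
The pipelines above seed the running row max m with -infinity and write -infinity into masked-out entries of S, so rowmax and exp treat them as absent; the raw_m == -infinity guard then zeroes rows that are fully masked. A small standalone sketch of that convention:

#include <algorithm>
#include <limits>

// Masked scores never win the rowmax, exp(-inf - m) == 0, and a fully
// masked row (raw_m == -inf) is clamped to 0 to avoid NaNs from inf - inf.
inline float masked_rowmax(const float* s, const bool* out_of_bound, int n)
{
    float m = -std::numeric_limits<float>::infinity();
    for(int j = 0; j < n; ++j)
        m = std::max(m, out_of_bound[j] ? -std::numeric_limits<float>::infinity() : s[j]);
    return m == -std::numeric_limits<float>::infinity() ? 0.f : m;
}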

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -231,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
@@ -251,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -389,12 +390,15 @@ struct BlockFmhaPipelineQRKSVSAsync
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -403,7 +407,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -454,7 +458,7 @@ struct BlockFmhaPipelineQRKSVSAsync
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -181,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -329,12 +330,15 @@ struct BlockFmhaPipelineQRKSVSFp8
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -343,7 +347,7 @@ struct BlockFmhaPipelineQRKSVSFp8
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -358,7 +362,7 @@ struct BlockFmhaPipelineQRKSVSFp8
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -338,12 +338,15 @@ struct BlockFmhaPipelineQSKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -352,7 +355,7 @@ struct BlockFmhaPipelineQSKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -367,7 +370,7 @@ struct BlockFmhaPipelineQSKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -9,6 +9,11 @@
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
@@ -97,9 +102,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -222,9 +226,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -918,12 +921,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8)
{
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {