mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
now can build
This commit is contained in:
@@ -95,6 +95,7 @@ else()
|
||||
-Wno-weak-vtables
|
||||
-Wno-covered-switch-default
|
||||
-Wno-unsafe-buffer-usage
|
||||
-Wno-unused-lambda-capture
|
||||
)
|
||||
else()
|
||||
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
@@ -9,11 +14,24 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -91,12 +109,12 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
if(nhead_k == 0)
|
||||
nhead_k = nhead;
|
||||
|
||||
@@ -143,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
|
||||
stream_config stream_config{
|
||||
ck_tile::stream_config stream_config{
|
||||
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
|
||||
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
@@ -207,53 +225,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
|
||||
HostTensor<QDataType> q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
HostTensor<KDataType> k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
HostTensor<VDataType> v_host(
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
HostTensor<BiasDataType> bias_host(
|
||||
use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
use_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
HostTensor<LSEDataType> lse_host(
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
HostTensor<ODataType> o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::utils::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::utils::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::utils::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::utils::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::utils::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::utils::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::utils::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::utils::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::utils::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
|
||||
DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -349,19 +371,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
const auto v_host_ref_lengths = std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
|
||||
const auto v_host_ref_lengths =
|
||||
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
|
||||
const auto v_host_ref_strides =
|
||||
is_v_rowmajor ? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
|
||||
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
|
||||
is_v_rowmajor
|
||||
? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
|
||||
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
|
||||
|
||||
HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
|
||||
HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
||||
|
||||
HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
@@ -386,7 +410,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// clang-format on
|
||||
|
||||
// reference
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
@@ -396,7 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(use_bias)
|
||||
{
|
||||
HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
|
||||
@@ -406,43 +430,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
reference_batched_elementwise<SMPLComputeDataType,
|
||||
BiasDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
BiasDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
reference_batched_masking<SaccDataType>(
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
reference_batched_masking<SaccDataType>(
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
reference_batched_masking<SaccDataType>(
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
if(lse)
|
||||
{
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref);
|
||||
}
|
||||
|
||||
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref);
|
||||
|
||||
HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
// permute
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
@@ -450,7 +474,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
bool cur_pass = ck_tile::utils::check_err(
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
@@ -466,17 +490,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(lse)
|
||||
{
|
||||
HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
bool lse_pass = ck_tile::utils::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= lse_pass;
|
||||
if(!cur_pass)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
@@ -19,11 +20,11 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
@@ -34,11 +35,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
@@ -48,12 +49,12 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float; // TODO: fix me
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using BiasDataType = float; // TODO: fix me
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
@@ -64,11 +65,11 @@ struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
@@ -107,7 +108,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
|
||||
ck_tile::index_t mask_x)
|
||||
{
|
||||
constexpr bool is_v_rowmajor =
|
||||
ck_tile::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -298,26 +299,26 @@ template <ck_tile::index_t HDim_,
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const stream_config&, fmha_fwd_args);
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
@@ -332,4 +333,4 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const stream_config&);
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -103,12 +103,12 @@ using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::FmhaFwdEpilogue<FmhaFwdEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
@@ -117,7 +117,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
|
||||
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
@@ -131,7 +131,7 @@ float fmha_fwd_<trait_{F_idx}>(const stream_config& s, fmha_fwd_args a)
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const stream_config& s){{
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
|
||||
@@ -51,7 +51,7 @@ struct mask_info
|
||||
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
|
||||
assert(0);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
|
||||
21
example/ck_tile/remod.py
Normal file
21
example/ck_tile/remod.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import os
|
||||
import copy
|
||||
|
||||
all_files = []
|
||||
for p in sorted(Path("./").rglob("*")):
|
||||
if p.suffix in ['.hpp', '.cpp']:
|
||||
all_files.append(pathlib.PurePath(p))
|
||||
|
||||
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
|
||||
cmd = f'clang-format-12 -style=file -i {str(x)}'
|
||||
#for xp in x.parents:
|
||||
#print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
#print(all_files)
|
||||
@@ -336,8 +336,8 @@ struct buffer_store<2>
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
static_assert(sizeof(T) == 2);
|
||||
using mbuf_t = short;
|
||||
asm volatile(
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
@@ -468,9 +468,9 @@ struct buffer_store_if<2>
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
static_assert(sizeof(T) == 2);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = short;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
@@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
|
||||
}
|
||||
|
||||
// buffer load i8
|
||||
CK_TILE_DEVICE int8_t
|
||||
CK_TILE_DEVICE_EXTERN int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
|
||||
|
||||
CK_TILE_DEVICE int8x2_t
|
||||
CK_TILE_DEVICE_EXTERN int8x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
|
||||
|
||||
CK_TILE_DEVICE int8x4_t
|
||||
CK_TILE_DEVICE_EXTERN int8x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
// buffer load i16
|
||||
CK_TILE_DEVICE int16_t
|
||||
CK_TILE_DEVICE_EXTERN int16_t
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
|
||||
|
||||
CK_TILE_DEVICE int16x2_t
|
||||
CK_TILE_DEVICE_EXTERN int16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
|
||||
|
||||
CK_TILE_DEVICE int16x4_t
|
||||
CK_TILE_DEVICE_EXTERN int16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
|
||||
|
||||
// buffer load i32
|
||||
CK_TILE_DEVICE int32_t
|
||||
CK_TILE_DEVICE_EXTERN int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
|
||||
CK_TILE_DEVICE int32x2_t
|
||||
CK_TILE_DEVICE_EXTERN int32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
CK_TILE_DEVICE int32x4_t
|
||||
CK_TILE_DEVICE_EXTERN int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
// buffer load fp16
|
||||
CK_TILE_DEVICE fp16_t
|
||||
CK_TILE_DEVICE_EXTERN _Float16
|
||||
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
|
||||
|
||||
CK_TILE_DEVICE fp16x4_t
|
||||
CK_TILE_DEVICE_EXTERN fp16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
|
||||
|
||||
// buffer load fp32
|
||||
CK_TILE_DEVICE float
|
||||
CK_TILE_DEVICE_EXTERN float
|
||||
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
|
||||
|
||||
CK_TILE_DEVICE fp32x2_t
|
||||
CK_TILE_DEVICE_EXTERN fp32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
|
||||
|
||||
CK_TILE_DEVICE fp32x4_t
|
||||
CK_TILE_DEVICE_EXTERN fp32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
|
||||
|
||||
// buffer store i8
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
// buffer store i16
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
// buffer store i32
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// buffer store fp16
|
||||
CK_TILE_DEVICE void
|
||||
llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata,
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
|
||||
|
||||
// buffer store fp32
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
|
||||
// buffer atomic-add fp16
|
||||
CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
@@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer atomic-max fp64
|
||||
CK_TILE_DEVICE double
|
||||
CK_TILE_DEVICE_EXTERN double
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
@@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<fp16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<bf16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<bf16x2_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<bf16x4_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array<T, N> src_thread_data,
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(
|
||||
src_thread_data.template get_as<bf16x4_t>()[number<0>{}],
|
||||
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(
|
||||
src_thread_data.template get_as<bf16x4_t>()[number<1>{}],
|
||||
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(bf16_t),
|
||||
@@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array<T, N>& src_thread_data,
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
|
||||
@@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop()
|
||||
{
|
||||
#if 1
|
||||
asm volatile("\
|
||||
s_nop 0 \n \
|
||||
" ::);
|
||||
#else
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_up(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_down(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,13 +9,15 @@
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST __host__
|
||||
#define CK_TILE_DEVICE __device__
|
||||
#define CK_TILE_HOST_DEVICE __host__ __device__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
@@ -122,7 +124,7 @@
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
@@ -132,3 +134,7 @@
|
||||
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
@@ -21,7 +21,12 @@ struct array
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
// TODO: do we need this?
|
||||
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
|
||||
// union {
|
||||
value_type data[N];
|
||||
// bulk_type __content;
|
||||
//};
|
||||
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
|
||||
// TODO: will initialize the data[] with the last value repeatedly
|
||||
// behavior different from std
|
||||
@@ -44,18 +49,24 @@ struct array
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
|
||||
template <typename Y>
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
|
||||
{
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = c;
|
||||
}
|
||||
template <typename ArrayType>
|
||||
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
|
||||
{
|
||||
static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = o.data[i];
|
||||
data[i] = static_cast<value_type>(c);
|
||||
}
|
||||
// template <typename Y>
|
||||
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// }
|
||||
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// return *this;
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
@@ -147,10 +158,10 @@ struct vector_traits<array<T, N>>
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs)
|
||||
{
|
||||
using value_type = remove_cvref_t<T>;
|
||||
return array<value_type, sizeof...(Ts) + 1>{std::forward<T>(x), std::forward<Ts>(xs)...};
|
||||
return array<value_type, sizeof...(Ts)>{std::forward<Ts>(xs)...};
|
||||
}
|
||||
|
||||
// make empty array
|
||||
|
||||
@@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
// constexpr index_t can't be captured "-Wunused-lambda-capture"
|
||||
// TODO: this is ugly
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
@@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
}()
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
result = f(result, Seq::get(i));
|
||||
result = f(result, Seq::at(i));
|
||||
}
|
||||
|
||||
return result;
|
||||
@@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
flag = flag || f(Seq::get(i));
|
||||
flag = flag || f(Seq::at(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
@@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
|
||||
|
||||
for(index_t i = 0; i < Seq::size(); ++i)
|
||||
{
|
||||
flag = flag && f(Seq::get(i));
|
||||
flag = flag && f(Seq::at(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
@@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
// template <index_t... Is>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple<number<Is>...>)
|
||||
// {
|
||||
// return sequence<Is...>{};
|
||||
// }
|
||||
template <class... T>
|
||||
struct tuple;
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
|
||||
{
|
||||
return sequence<Is...>{};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
|
||||
|
||||
@@ -139,6 +139,26 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
}
|
||||
});
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
|
||||
@@ -237,21 +257,21 @@ template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{});
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
@@ -490,58 +510,58 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
} // namespace std
|
||||
|
||||
#if 1
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
|
||||
#else
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile:: \
|
||||
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
@@ -4,44 +4,36 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator==(const type_& x, const type_& y) \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator!=(const type_& x, const type_& y) \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<(const type_& x, const type_& y) \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator<=(const type_& x, const type_& y) \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>(const type_& x, const type_& y) \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
bool operator>=(const type_& x, const type_& y) \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator+(const type_& x, const type_& y) \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x) \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
@@ -49,66 +41,55 @@
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator-(const type_& x, const type_& y) \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator*(const type_& x, const type_& y) \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator/(const type_& x, const type_& y) \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator+=(type_& x, const type_& y) \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator-=(type_& x, const type_& y) \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator*=(type_& x, const type_& y) \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator/=(type_& x, const type_& y) \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator++(type_& x) \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_& operator--(type_& x) \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator++(type_& x, int) \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
CK_TILE_HOST_DEVICE \
|
||||
type_ operator--(type_& x, int) \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
|
||||
@@ -24,9 +24,16 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
@@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
@@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
@@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
@@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x)
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
@@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
|
||||
{
|
||||
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
@@ -240,7 +274,7 @@ struct numeric_limits<bfloat16_t>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = 0x80;
|
||||
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
@@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
|
||||
// check if x is 0.0
|
||||
if(x_bitwise == 0)
|
||||
return 0;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of implict 1
|
||||
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
|
||||
@@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
|
||||
}
|
||||
}
|
||||
|
||||
// check if x is 0.0 or -0.0
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
|
||||
return __builtin_bit_cast(
|
||||
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
|
||||
mantissa &= (1 << out_mant) - 1;
|
||||
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
|
||||
return __builtin_bit_cast(Y,
|
||||
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
|
||||
(out_exponent << out_mant) | mantissa));
|
||||
}
|
||||
|
||||
template <typename X, typename Y, bool negative_zero_nan>
|
||||
@@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
|
||||
|
||||
// prepare the codes
|
||||
constexpr X nan_code = 0x80;
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
|
||||
|
||||
@@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
|
||||
|
||||
// check if x is 0.0
|
||||
if(x == 0)
|
||||
if(x_raw == 0)
|
||||
return static_cast<Y>(0);
|
||||
|
||||
// unpack the input
|
||||
uint32_t sign = x >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x & ((1 << in_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> in_mant;
|
||||
uint32_t sign = x_raw >> (in_exp + in_mant);
|
||||
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
|
||||
int exponent = (x_raw & 0x7F) >> in_mant;
|
||||
|
||||
constexpr int exp_low_cutoff =
|
||||
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
@@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if constexpr(negative_zero_nan)
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return NaN;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == nan_code)
|
||||
if(x_raw == nan_code)
|
||||
return Neg0;
|
||||
if(exponent == ((1 << in_exp) - 1))
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
@@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
|
||||
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x;
|
||||
retval = x_raw;
|
||||
retval <<= 8;
|
||||
return *(reinterpret_cast<const Y*>(&retval));
|
||||
}
|
||||
@@ -700,8 +704,8 @@ struct numeric_limits<bf8_t>
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
|
||||
};
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
@@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
// HIP use fp16_hip_t as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
@@ -46,6 +53,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
@@ -61,6 +72,10 @@ struct alignas(2) half_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const
|
||||
@@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
@@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t double_to_fp16(const double& x) { return half_t{x}; }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
@@ -156,94 +187,94 @@ struct numeric_utils<half_t>
|
||||
};
|
||||
|
||||
// arithmetic
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -251,7 +282,7 @@ half_t operator++(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
@@ -259,6 +290,8 @@ half_t operator--(half_t& x, int)
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
|
||||
|
||||
@@ -14,8 +14,9 @@ struct constant
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <typename T, T v>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
@@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
|
||||
{
|
||||
// TODO: x need to be 1 ~ 0x7fffffff
|
||||
// __builtin_clz will produce unexpected result if x is 0;
|
||||
return 31 - clz(x);
|
||||
return 31 - __builtin_clz(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
|
||||
@@ -275,7 +276,7 @@ struct log2e<float>
|
||||
};
|
||||
|
||||
template <typename T = double>
|
||||
inline constexpr T log2e_v = log2e<T>::value;
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -298,16 +299,32 @@ bool isnan(const float& x)
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __expf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp(float x) { return std::expf(x); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp2(float x) { return std::exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
|
||||
@@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bfp16
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
@@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
|
||||
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
|
||||
@@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
using InVecType = typename InVec::type;
|
||||
using OutVecType = typename OutVec::type;
|
||||
// using InVec = typename InVec::type;
|
||||
// using OutVec = typename OutVec::type;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
@@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVecType>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
});
|
||||
|
||||
// transpose
|
||||
@@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVecType>(
|
||||
number<out_offset / sizeof(OutVecType)>{},
|
||||
out_vectors[i].template get_as<OutVecType>()[I0]);
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>{}); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
|
||||
@@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
|
||||
{
|
||||
return Base::get_top_dimension_length();
|
||||
return Base::get_top_dimension_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
|
||||
|
||||
@@ -296,7 +296,8 @@ CK_TILE_HOST_DEVICE constexpr auto
|
||||
&rh_major_minor_to_hidden_ids,
|
||||
&rh_major_minor_to_hidden_lengths](auto idim_x) {
|
||||
// typename HsLengthss::base{}.foo();
|
||||
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
constexpr auto h_minor_lengths =
|
||||
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
|
||||
|
||||
constexpr index_t ndim_h_minor = h_minor_lengths.size();
|
||||
@@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
using old_scan =
|
||||
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
@@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t _split_idx =
|
||||
std::conditional_t<_split_flag, number<id>, number<0>>::value;
|
||||
@@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t split_idx =
|
||||
std::conditional_t<split_flag, number<id>, number<0>>::value;
|
||||
@@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask,
|
||||
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
|
||||
SliceSize>;
|
||||
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
|
||||
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
|
||||
"can not evenly divide this sequence, please check");
|
||||
return make_tuple(typename sliced_type::dim_lengths{},
|
||||
typename sliced_type::dim_slices{},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
|
||||
#if 1
|
||||
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
#else
|
||||
|
||||
@@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
template <typename T>
|
||||
using remove_pointer_t = typename std::remove_pointer<T>::type;
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = std::is_arithmetic<T>::v ? false : T::is_static();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
@@ -69,6 +48,36 @@ struct nonesuch
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
using has_is_static = decltype(T::is_static());
|
||||
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_is_static, T>{})
|
||||
return T::is_static();
|
||||
else
|
||||
return std::is_arithmetic<T>::value;
|
||||
}();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
// FIXME: do we need this anymore?
|
||||
template <
|
||||
typename PY,
|
||||
|
||||
@@ -40,7 +40,7 @@ typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_floating_point_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
bool>::type
|
||||
bool>::type CK_TILE_HOST
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
@@ -98,7 +98,7 @@ template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
|
||||
bool>::type
|
||||
bool>::type CK_TILE_HOST
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
@@ -157,7 +157,7 @@ template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
bool>::type
|
||||
bool>::type CK_TILE_HOST
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
@@ -182,7 +182,7 @@ check_err(const Range& out,
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
|
||||
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
@@ -220,11 +220,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
#endif
|
||||
,
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double = 0,
|
||||
double atol = 0)
|
||||
CK_TILE_HOST check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double = 0,
|
||||
double atol = 0)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
@@ -270,12 +270,12 @@ template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
CK_TILE_HOST check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
@@ -323,12 +323,12 @@ template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
CK_TILE_HOST check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
// To be removed, which really does not tell the location of failed HIP functional call
|
||||
inline void hip_check_error(hipError_t x)
|
||||
CK_TILE_HOST void hip_check_error(hipError_t x)
|
||||
{
|
||||
if(x != hipSuccess)
|
||||
{
|
||||
|
||||
@@ -18,11 +18,11 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Range>
|
||||
std::ostream& LogRange(std::ostream& os,
|
||||
Range&& range,
|
||||
std::string delim,
|
||||
int precision = std::cout.precision(),
|
||||
int width = 0)
|
||||
CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
|
||||
Range&& range,
|
||||
std::string delim,
|
||||
int precision = std::cout.precision(),
|
||||
int width = 0)
|
||||
{
|
||||
bool first = true;
|
||||
for(auto&& v : range)
|
||||
@@ -37,11 +37,11 @@ std::ostream& LogRange(std::ostream& os,
|
||||
}
|
||||
|
||||
template <typename T, typename Range>
|
||||
std::ostream& LogRangeAsType(std::ostream& os,
|
||||
Range&& range,
|
||||
std::string delim,
|
||||
int precision = std::cout.precision(),
|
||||
int width = 0)
|
||||
CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
|
||||
Range&& range,
|
||||
std::string delim,
|
||||
int precision = std::cout.precision(),
|
||||
int width = 0)
|
||||
{
|
||||
bool first = true;
|
||||
for(auto&& v : range)
|
||||
@@ -56,13 +56,13 @@ std::ostream& LogRangeAsType(std::ostream& os,
|
||||
}
|
||||
|
||||
template <typename F, typename T, std::size_t... Is>
|
||||
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
|
||||
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
|
||||
{
|
||||
return f(std::get<Is>(args)...);
|
||||
}
|
||||
|
||||
template <typename F, typename T>
|
||||
auto call_f_unpack_args(F f, T args)
|
||||
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
|
||||
{
|
||||
constexpr std::size_t N = std::tuple_size<T>{};
|
||||
|
||||
@@ -70,13 +70,13 @@ auto call_f_unpack_args(F f, T args)
|
||||
}
|
||||
|
||||
template <typename F, typename T, std::size_t... Is>
|
||||
auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
|
||||
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
|
||||
{
|
||||
return F(std::get<Is>(args)...);
|
||||
}
|
||||
|
||||
template <typename F, typename T>
|
||||
auto construct_f_unpack_args(F, T args)
|
||||
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
|
||||
{
|
||||
constexpr std::size_t N = std::tuple_size<T>{};
|
||||
|
||||
@@ -87,7 +87,19 @@ struct HostTensorDescriptor
|
||||
{
|
||||
HostTensorDescriptor() = default;
|
||||
|
||||
void CalculateStrides();
|
||||
void CalculateStrides()
|
||||
{
|
||||
mStrides.clear();
|
||||
mStrides.resize(mLens.size(), 0);
|
||||
if(mStrides.empty())
|
||||
return;
|
||||
|
||||
mStrides.back() = 1;
|
||||
std::partial_sum(mLens.rbegin(),
|
||||
mLens.rend() - 1,
|
||||
mStrides.rbegin() + 1,
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
|
||||
HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
|
||||
@@ -123,12 +135,28 @@ struct HostTensorDescriptor
|
||||
{
|
||||
}
|
||||
|
||||
std::size_t get_num_of_dimension() const;
|
||||
std::size_t get_element_size() const;
|
||||
std::size_t get_element_space_size() const;
|
||||
std::size_t get_num_of_dimension() const { return mLens.size(); }
|
||||
std::size_t get_element_size() const
|
||||
{
|
||||
assert(mLens.size() == mStrides.size());
|
||||
return std::accumulate(
|
||||
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
|
||||
}
|
||||
std::size_t get_element_space_size() const
|
||||
{
|
||||
std::size_t space = 1;
|
||||
for(std::size_t i = 0; i < mLens.size(); ++i)
|
||||
{
|
||||
if(mLens[i] == 0)
|
||||
continue;
|
||||
|
||||
const std::vector<std::size_t>& get_lengths() const;
|
||||
const std::vector<std::size_t>& GetStrides() const;
|
||||
space += (mLens[i] - 1) * mStrides[i];
|
||||
}
|
||||
return space;
|
||||
}
|
||||
|
||||
const std::vector<std::size_t>& get_lengths() const { return mLens; }
|
||||
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
|
||||
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
@@ -151,8 +179,8 @@ struct HostTensorDescriptor
|
||||
};
|
||||
|
||||
template <typename New2Old>
|
||||
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
|
||||
const New2Old& new2old)
|
||||
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(
|
||||
const HostTensorDescriptor& a, const New2Old& new2old)
|
||||
{
|
||||
std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
|
||||
std::vector<std::size_t> new_strides(a.get_num_of_dimension());
|
||||
@@ -238,7 +266,7 @@ struct ParallelTensorFunctor
|
||||
};
|
||||
|
||||
template <typename F, typename... Xs>
|
||||
auto make_ParallelTensorFunctor(F f, Xs... xs)
|
||||
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
|
||||
{
|
||||
return ParallelTensorFunctor<F, Xs...>(f, xs...);
|
||||
}
|
||||
|
||||
@@ -20,12 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(const stream_config& s,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TILE_TIME_KERNEL
|
||||
if(s.time_kernel_)
|
||||
@@ -75,13 +75,13 @@ float launch_and_time_kernel(const stream_config& s,
|
||||
}
|
||||
|
||||
template <typename... Args, typename F, typename PreProcessFunc>
|
||||
float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TILE_TIME_KERNEL
|
||||
if(s.time_kernel_)
|
||||
@@ -151,12 +151,12 @@ template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
float launch_kernel(const stream_config& s,
|
||||
KernelImpl kernel_impl,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t dynamic_smem_byte,
|
||||
Args... args)
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s,
|
||||
KernelImpl kernel_impl,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t dynamic_smem_byte,
|
||||
Args... args)
|
||||
{
|
||||
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
// ranges implementation are not intented to be used by user
|
||||
// TODO: do we need this?
|
||||
namespace ck_tile {
|
||||
namespace ranges {
|
||||
|
||||
template <typename T>
|
||||
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;
|
||||
@@ -21,8 +20,7 @@ using iter_reference_t = decltype(*std::declval<T&>());
|
||||
template <typename T>
|
||||
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;
|
||||
|
||||
//.........................
|
||||
|
||||
namespace ranges {
|
||||
template <typename R>
|
||||
using iterator_t = decltype(std::begin(std::declval<R&>()));
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@ template <typename ADataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename BinaryElementOp = ck_tile::plus<AccDataType>>
|
||||
void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
|
||||
const HostTensor<BDataType>& b_b_m_n,
|
||||
HostTensor<CDataType>& c_b_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const BinaryElementOp& binary_element_op = {})
|
||||
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
|
||||
const HostTensor<BDataType>& b_b_m_n,
|
||||
HostTensor<CDataType>& c_b_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const BinaryElementOp& binary_element_op = {})
|
||||
{
|
||||
const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@ template <typename ADataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
|
||||
const HostTensor<BDataType>& b_b_n_k,
|
||||
HostTensor<CDataType>& c_b_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
|
||||
const HostTensor<BDataType>& b_b_n_k,
|
||||
HostTensor<CDataType>& c_b_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = b_b_n_k.mDesc.get_lengths()[2];
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CDataType, typename MaskingType>
|
||||
void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
|
||||
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
|
||||
{
|
||||
const int M = c_b_m_n.mDesc.get_lengths()[1];
|
||||
const int N = c_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename CompDataType, typename BDataType>
|
||||
void reference_batched_softmax(
|
||||
CK_TILE_HOST void reference_batched_softmax(
|
||||
const HostTensor<ADataType>& a_b_m_n,
|
||||
HostTensor<BDataType>& b_b_m_n,
|
||||
std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
|
||||
|
||||
@@ -16,12 +16,12 @@ template <typename ADataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_n_k,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_n_k,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int K = b_n_k.mDesc.get_lengths()[1];
|
||||
|
||||
@@ -10,25 +10,25 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
void reference_im2col(HostTensor<T>& in_mtx_host_ref,
|
||||
const HostTensor<T>& in_host,
|
||||
int /*N*/,
|
||||
int /*K*/,
|
||||
int C,
|
||||
int /*Y*/,
|
||||
int X,
|
||||
int Hi,
|
||||
int Wi,
|
||||
int Ho,
|
||||
int Wo,
|
||||
int ConvStrideH,
|
||||
int ConvStrideW,
|
||||
int ConvDilationH,
|
||||
int ConvDilationW,
|
||||
int InLeftPadH,
|
||||
int InLeftPadW,
|
||||
int /*InRightPadH*/,
|
||||
int /*InRightPadW*/)
|
||||
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
|
||||
const HostTensor<T>& in_host,
|
||||
int /*N*/,
|
||||
int /*K*/,
|
||||
int C,
|
||||
int /*Y*/,
|
||||
int X,
|
||||
int Hi,
|
||||
int Wi,
|
||||
int Ho,
|
||||
int Wo,
|
||||
int ConvStrideH,
|
||||
int ConvStrideW,
|
||||
int ConvDilationH,
|
||||
int ConvDilationW,
|
||||
int InLeftPadH,
|
||||
int InLeftPadW,
|
||||
int /*InRightPadH*/,
|
||||
int /*InRightPadW*/)
|
||||
{
|
||||
int GemmM = in_mtx_host_ref.get_lengths()[0];
|
||||
int GemmK = in_mtx_host_ref.get_lengths()[1];
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
|
||||
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = a_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
@@ -10,12 +10,13 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
void reference_softmax(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m_n)
|
||||
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
|
||||
HostTensor<BDataType>& b_m_n)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = a_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
AccDataType v_max = ck_tile::NumericLimits<ADataType>::Lowest();
|
||||
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
|
||||
@@ -575,9 +575,8 @@ struct FmhaFwdKernel
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
|
||||
/// following copy capture of the 'i_nhead'
|
||||
/// if compiled in C++20
|
||||
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
|
||||
/// following copy capture of the 'i_nhead' if in C++20
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto bias_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -189,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -208,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -346,12 +347,15 @@ struct BlockFmhaPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,7 +364,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -375,7 +379,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -231,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -251,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -389,12 +390,15 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,7 +407,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -454,7 +458,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -181,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -329,12 +330,15 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,7 +347,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -358,7 +362,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -338,12 +338,15 @@ struct BlockFmhaPipelineQSKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,7 +355,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -367,7 +370,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
|
||||
// TODO: remove this
|
||||
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
|
||||
@@ -97,9 +102,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
}
|
||||
@@ -222,9 +226,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
}
|
||||
@@ -918,12 +921,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType>,
|
||||
2>>{};
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
|
||||
typename Problem::VDataType>,
|
||||
2>>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user