diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 87cb8cdf11..8654170b3d 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -95,6 +95,7 @@ else() -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage + -Wno-unused-lambda-capture ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 9e11f4d19e..f54049cfcc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "fmha_fwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + #include #include #include @@ -9,11 +14,24 @@ #include #include #include +#include -#include "fmha_fwd.hpp" -#include "ck_tile/host.hpp" -#include "mask.hpp" -#include "utils.hpp" +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} auto create_args(int argc, char* argv[]) { @@ -91,12 +109,12 @@ auto get_elimit(int init_method) template bool run(const ck_tile::ArgParser& arg_parser) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); if(nhead_k == 0) nhead_k = nhead; @@ -143,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - stream_config stream_config{ + ck_tile::stream_config stream_config{ nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); @@ -207,53 +225,57 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); - HostTensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - HostTensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - HostTensor v_host( + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host // will not be used for verification at all (but will be copied to device anyway). - HostTensor bias_host( - use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor bias_host( + use_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] - HostTensor lse_host( + ck_tile::HostTensor lse_host( lse ? std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); - HostTensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); if(init_method == 0) { - ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); } else if(init_method == 1) { - ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } else if(init_method == 2) { - ck_tile::utils::FillTrigValue{}(q_host); - ck_tile::utils::FillTrigValue{}(k_host); - ck_tile::utils::FillTrigValue{}(v_host); - ck_tile::utils::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); } - DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); - DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); - DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); - DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); - DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); - DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -349,19 +371,21 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); - const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_lengths = + std::array{nhead, hdim_v, real_seqlen_k}; const auto v_host_ref_strides = - is_v_rowmajor ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} - : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + is_v_rowmajor + ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; - HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); - HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - HostTensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); - HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); - HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - HostTensor lse_host_ref({nhead, real_seqlen_q}); + ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); ck_tile::index_t nr = nhead / nhead_k; @@ -386,7 +410,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format on // reference - reference_batched_gemm( + ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, @@ -396,7 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(use_bias) { - HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); @@ -406,43 +430,43 @@ bool run(const ck_tile::ArgParser& arg_parser) // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, // real_seqlen_k] - reference_batched_elementwise( + ck_tile::reference_batched_elementwise( s_host_ref, bias_host_ref, s_host_ref); } if(mask.type == mask_enum::no_mask) { - reference_batched_masking( + ck_tile::reference_batched_masking( s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); } else if(mask.type == mask_enum::window_generic) { - reference_batched_masking( + ck_tile::reference_batched_masking( s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } else { - reference_batched_masking( + ck_tile::reference_batched_masking( s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } if(lse) { - reference_batched_softmax( + ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, lse_host_ref); } else { - reference_batched_softmax( + ck_tile::reference_batched_softmax( s_host_ref, p_host_ref); } - reference_batched_gemm( + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref); - HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off // permute if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); @@ -450,7 +474,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format on auto [rtol, atol] = get_elimit(init_method); - bool cur_pass = ck_tile::utils::check_err( + bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); pass &= cur_pass; if(!cur_pass) @@ -466,17 +490,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { - HostTensor lse_host_result({nhead, real_seqlen_q}); + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); lse_host_result.ForEach([&](auto& self, auto idx) { self(idx) = lse_host(b, idx[0], idx[1] + query_offset); }); - bool lse_pass = ck_tile::utils::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + bool lse_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); pass &= lse_pass; if(!cur_pass) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 325ff6b78a..49846a322d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include template struct FmhaFwdTypeConfig; @@ -19,11 +20,11 @@ struct FmhaFwdTypeConfig using KDataType = ck_tile::half_t; using VDataType = ck_tile::half_t; using BiasDataType = ck_tile::half_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation + using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::half_t; }; @@ -34,11 +35,11 @@ struct FmhaFwdTypeConfig using KDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t; using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation + using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::bf16_t; }; @@ -48,12 +49,12 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t; - using BiasDataType = float; // TODO: fix me - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using BiasDataType = float; // TODO: fix me + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation + using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::fp8_t; }; @@ -64,11 +65,11 @@ struct FmhaFwdTypeConfig using KDataType = ck_tile::bf8_t; using VDataType = ck_tile::bf8_t; using BiasDataType = ck_tile::bf8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation + using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::bf8_t; }; @@ -107,7 +108,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, ck_tile::index_t mask_x) { constexpr bool is_v_rowmajor = - ck_tile::is_same_v; + std::is_same_v; assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, @@ -298,26 +299,26 @@ template ; - static constexpr bool kIsGroupMode = kIsGroupMode_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN0 = kN0_; static constexpr ck_tile::index_t kK0 = kK0_; static constexpr ck_tile::index_t kN1 = kN1_; static constexpr ck_tile::index_t kK1 = kK1_; static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kHasBias = kHasBias_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; }; template -float fmha_fwd_(const stream_config&, fmha_fwd_args); +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); // This is the public API, will be generated by script struct fmha_fwd_traits @@ -332,4 +333,4 @@ struct fmha_fwd_traits bool has_lse; // TODO: padding check is inside this api }; -float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const stream_config&); +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 66feae6a5d..7287fef8a2 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -103,12 +103,12 @@ using fmha_pipeline_{F_idx} = {F_pipeline}< fmha_pipeline_problem_{F_idx}>; using fmha_epilogue_{F_idx} = - ck_tile::FmhaFwdEpilogue::OaccDataType, + ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel, + ck_tile::FmhaFwdKernel, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; @@ -117,7 +117,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F #include template<> -float fmha_fwd_(const stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -131,7 +131,7 @@ float fmha_fwd_(const stream_config& s, fmha_fwd_args a) FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const stream_config& s){{ +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 107f1f61d0..d652172ede 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -51,7 +51,7 @@ struct mask_info printf("not supported value %s, %s\n", v.c_str(), str.c_str()); assert(0); } - tmp.type = mask_enum::window_generic; + tmp.type = mask_enum::window_generic; ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); // TODO: some validation diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py new file mode 100644 index 0000000000..fdc0dcf5d7 --- /dev/null +++ b/example/ck_tile/remod.py @@ -0,0 +1,21 @@ +import pathlib +from pathlib import Path +import subprocess +import os +import copy + +all_files = [] +for p in sorted(Path("./").rglob("*")): + if p.suffix in ['.hpp', '.cpp']: + all_files.append(pathlib.PurePath(p)) + + +# formatting +for x in all_files: + subprocess.Popen(f'dos2unix {str(x)}', shell=True) + cmd = f'clang-format-12 -style=file -i {str(x)}' + #for xp in x.parents: + #print(get_file_base(x)) + subprocess.Popen(cmd, shell=True) + +#print(all_files) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 9a7c95f4c2..61ccde3804 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -336,8 +336,8 @@ struct buffer_store<2> index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { - static_assert(sizeof(T) == 4); - using mbuf_t = float; + static_assert(sizeof(T) == 2); + using mbuf_t = short; asm volatile( "buffer_store_short %0, %1, %2, %3 offen offset:%4" : @@ -468,9 +468,9 @@ struct buffer_store_if<2> index_t i_offset /*max 0xFFF*/, index_t flag = 1) { - static_assert(sizeof(T) == 4); + static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" @@ -606,116 +606,116 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) } // buffer load i8 -CK_TILE_DEVICE int8_t +CK_TILE_DEVICE_EXTERN int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); -CK_TILE_DEVICE int8x2_t +CK_TILE_DEVICE_EXTERN int8x2_t llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); -CK_TILE_DEVICE int8x4_t +CK_TILE_DEVICE_EXTERN int8x4_t llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); // buffer load i16 -CK_TILE_DEVICE int16_t +CK_TILE_DEVICE_EXTERN int16_t llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); -CK_TILE_DEVICE int16x2_t +CK_TILE_DEVICE_EXTERN int16x2_t llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); -CK_TILE_DEVICE int16x4_t +CK_TILE_DEVICE_EXTERN int16x4_t llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); // buffer load i32 -CK_TILE_DEVICE int32_t +CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); -CK_TILE_DEVICE int32x2_t +CK_TILE_DEVICE_EXTERN int32x2_t llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); -CK_TILE_DEVICE int32x4_t +CK_TILE_DEVICE_EXTERN int32x4_t llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); // buffer load fp16 -CK_TILE_DEVICE fp16_t +CK_TILE_DEVICE_EXTERN _Float16 llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); -CK_TILE_DEVICE fp16x2_t +CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); -CK_TILE_DEVICE fp16x4_t +CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); // buffer load fp32 -CK_TILE_DEVICE float +CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); -CK_TILE_DEVICE fp32x2_t +CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); -CK_TILE_DEVICE fp32x4_t +CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); // buffer store i8 -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, int32x4_t rsrc, index_t voffset, @@ -723,43 +723,43 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); // buffer store i16 -CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata, +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); -CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata, +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); -CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata, +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); // buffer store i32 -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, index_t voffset, @@ -767,21 +767,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); // buffer store fp16 -CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata, +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata, int32x4_t rsrc, index_t voffset, @@ -789,21 +789,21 @@ llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); // buffer store fp32 -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32(float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata, int32x4_t rsrc, index_t voffset, @@ -811,7 +811,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); // buffer atomic-add fp16 -CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( +CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( fp16x2_t vdata, int32x4_t rsrc, index_t voffset, @@ -819,7 +819,7 @@ CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); // buffer atomic-add i32 -CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( +CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( int32_t vdata, int32x4_t rsrc, index_t voffset, @@ -827,7 +827,7 @@ CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); // buffer atomic-add fp32 -CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32( +CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32( float vdata, int32x4_t rsrc, index_t voffset, @@ -835,7 +835,7 @@ CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32( index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); // buffer atomic-max fp64 -CK_TILE_DEVICE double +CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int32x4_t rsrc, // dst_wave_buffer_resource int voffset, // dst_thread_addr_offset @@ -1370,7 +1370,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_fp16(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1421,7 +1421,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1429,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1437,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, } else if constexpr(N == 4) { - llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1446,14 +1446,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, else if constexpr(N == 8) { llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<0>{}], + src_thread_data.template get_as()[number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, static_cast(coherence)); llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<1>{}], + src_thread_data.template get_as()[number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(bf16_t), @@ -1968,7 +1968,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const array& src_thread_data, } // Direct loads from global to LDS. -CK_TILE_DEVICE void +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, __attribute__((address_space(3))) uint32_t* lds_ptr, index_t size, diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 333168fd2a..888f0e728f 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -58,4 +58,36 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } +CK_TILE_DEVICE void block_sync_lds() +{ +#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} + +CK_TILE_DEVICE void block_sync_lds_direct_load() +{ + asm volatile("\ + s_waitcnt vmcnt(0) \n \ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +} + +CK_TILE_DEVICE void s_nop() +{ +#if 1 + asm volatile("\ + s_nop 0 \n \ + " ::); +#else + __builtin_amdgcn_sched_barrier(0); +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 1ab2ba1002..42508e66a6 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -9,6 +9,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include namespace ck_tile { @@ -24,4 +27,36 @@ CK_TILE_DEVICE void m0_inc_with_memory(index_t v) asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory"); } +template +CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta) +{ +#if 0 + return __shfl_up(v_local, lane_delta); +#elif 1 + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const uint32_t wrap_around_lane_delta = warpSize - lane_delta; + + const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( + (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast(v_local)); + + return bit_cast(v_remote_tmp); +#endif +} + +template +CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) +{ +#if 0 + return __shfl_down(v_local, lane_delta); +#elif 1 + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( + (__lane_id() << 2) + (lane_delta << 2), bit_cast(v_local)); + + return bit_cast(v_remote_tmp); +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index b655ae0a6c..3a318ec091 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -9,13 +9,15 @@ #endif #ifdef __HIPCC__ -#define CK_TILE_HOST __host__ -#define CK_TILE_DEVICE __device__ -#define CK_TILE_HOST_DEVICE __host__ __device__ +#define CK_TILE_HOST inline __host__ +#define CK_TILE_DEVICE inline __device__ +#define CK_TILE_HOST_DEVICE inline __host__ __device__ +#define CK_TILE_DEVICE_EXTERN __device__ #else #define CK_TILE_HOST inline #define CK_TILE_DEVICE inline #define CK_TILE_HOST_DEVICE inline +#define CK_TILE_DEVICE_EXTERN #endif #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0 @@ -122,7 +124,7 @@ #endif #ifndef __HIP_DEVICE_COMPILE__ // for host code -#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1 +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ defined(__gfx942__) // for GPU code @@ -132,3 +134,7 @@ #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif + +#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 +#endif diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 7752f31375..7f2041494b 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -21,7 +21,12 @@ struct array { using value_type = T_; static constexpr index_t N = N_; + // TODO: do we need this? + // using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type)))); + // union { value_type data[N]; + // bulk_type __content; + //}; CK_TILE_HOST_DEVICE constexpr array() : data{} {} // TODO: will initialize the data[] with the last value repeatedly // behavior different from std @@ -44,18 +49,24 @@ struct array data[i] = vlast; } } - CK_TILE_HOST_DEVICE explicit constexpr array(value_type c) + template + CK_TILE_HOST_DEVICE explicit constexpr array(Y c) { for(auto i = 0; i < size(); i++) - data[i] = c; - } - template - CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o) - { - static_assert(ArrayType::size() == size(), "wrong! size not the same"); - for(auto i = 0; i < size(); i++) - data[i] = o.data[i]; + data[i] = static_cast(c); } + // template + // CK_TILE_HOST_DEVICE constexpr array(const array& o) + // { + // // static_assert(ArrayType::size() == size(), "wrong! size not the same"); + // __content = o.__content; + // } + // CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o) + // { + // // static_assert(ArrayType::size() == size(), "wrong! size not the same"); + // __content = o.__content; + // return *this; + // } CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; } @@ -147,10 +158,10 @@ struct vector_traits> }; template -CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs) +CK_TILE_HOST_DEVICE constexpr auto make_array(Ts&&... xs) { using value_type = remove_cvref_t; - return array{std::forward(x), std::forward(xs)...}; + return array{std::forward(xs)...}; } // make empty array diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index eec15d2538..474eda80d1 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -484,7 +484,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) // constexpr index_t can't be captured "-Wunused-lambda-capture" // TODO: this is ugly #define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ - [a_of_b_impl, bs_sizes] { \ + [a_of_b_impl, bs_sizes] { \ return ck_tile::generate_tuple( \ [=](auto i) { \ constexpr auto b_impl = a_of_b_impl[i]; \ @@ -496,5 +496,4 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) }() #endif - } // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 581e3b8d61..e10ef40111 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -976,7 +976,7 @@ reduce_on_sequence(Seq, Reduce f, number /*initial_value*/) for(index_t i = 0; i < Seq::size(); ++i) { - result = f(result, Seq::get(i)); + result = f(result, Seq::at(i)); } return result; @@ -990,7 +990,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f) for(index_t i = 0; i < Seq::size(); ++i) { - flag = flag || f(Seq::get(i)); + flag = flag || f(Seq::at(i)); } return flag; @@ -1004,7 +1004,7 @@ CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f) for(index_t i = 0; i < Seq::size(); ++i) { - flag = flag && f(Seq::get(i)); + flag = flag && f(Seq::at(i)); } return flag; @@ -1039,11 +1039,14 @@ CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } -// template -// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple...>) -// { -// return sequence{}; -// } +template +struct tuple; + +template +CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple...>) +{ + return sequence{}; +} namespace detail { template diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index a47cf94811..c146cba9cf 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -139,6 +139,26 @@ struct tuple : impl::tuple_base, T...> // { // return {t...}; // } +template +CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple& a, const tuple& b) +{ + bool same = true; + + static_for<0, sizeof...(Xs), 1>{}([&](auto i) { + if(a[i] != b[i]) + { + same = false; + } + }); + + return same; +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple& a, const tuple& b) +{ + return !(a == b); +} template CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs) @@ -237,21 +257,21 @@ template CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x) { return detail::transform_tuples_impl( - f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); + f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } template CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y) { return detail::transform_tuples_impl( - f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); + f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } template CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) { return detail::transform_tuples_impl( - f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); + f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } // By default unroll to the flatten @@ -490,58 +510,58 @@ struct tuple_element> } // namespace std #if 1 -#define TO_TUPLE_OF_NUMBER(a, n) \ - _Pragma("clang diagnostic push") \ - _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \ - [a](ck_tile::sequence) \ - { \ - return ck_tile::tuple{}]>...>{}; \ - } \ - (ck_tile::make_index_sequence{}) \ - _Pragma("clang diagnostic pop") +#define TO_TUPLE_OF_NUMBER(a, n) \ + _Pragma("clang diagnostic push") _Pragma( \ + "clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ + ck_tile::sequence) \ + { \ + return ck_tile::tuple{}]>...>{}; \ + } \ + (ck_tile::make_index_sequence{}) _Pragma("clang diagnostic pop") #else -#define TO_TUPLE_OF_NUMBER(arr, n_) \ - [&arr, n_] { \ - static_assert(arr.size() >= n_, "wrong! out of bound"); \ - \ - static_assert(n_ < 7, "not implemented"); \ - \ - if constexpr(n_ == 0) \ - { \ - return ck_tile::tuple<>{}; \ - } \ - else if constexpr(n_ == 1) \ - { \ - return ck_tile::tuple>{}; \ - } \ - else if constexpr(n_ == 2) \ - { \ - return ck_tile::tuple, number>{}; \ - } \ - else if constexpr(n_ == 3) \ - { \ - return ck_tile::tuple, number, number>{}; \ - } \ - else if constexpr(n_ == 4) \ - { \ - return ck_tile::tuple, number, number, number>{}; \ - } \ - else if constexpr(n_ == 5) \ - { \ - return ck_tile::tuple, \ - number, \ - number, \ - number, \ - number>{}; \ - } \ - else if constexpr(n_ == 6) \ - { \ - return ck_tile::tuple, \ - number, \ - number, \ - number, \ - number, \ - number>{}; \ - } \ +#define TO_TUPLE_OF_NUMBER(arr, n_) \ + [&arr, n_] { \ + static_assert(arr.size() >= n_, "wrong! out of bound"); \ + \ + static_assert(n_ < 7, "not implemented"); \ + \ + if constexpr(n_ == 0) \ + { \ + return ck_tile::tuple<>{}; \ + } \ + else if constexpr(n_ == 1) \ + { \ + return ck_tile::tuple>{}; \ + } \ + else if constexpr(n_ == 2) \ + { \ + return ck_tile::tuple, number>{}; \ + } \ + else if constexpr(n_ == 3) \ + { \ + return ck_tile::tuple, number, number>{}; \ + } \ + else if constexpr(n_ == 4) \ + { \ + return ck_tile:: \ + tuple, number, number, number>{}; \ + } \ + else if constexpr(n_ == 5) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ + else if constexpr(n_ == 6) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ }() #endif diff --git a/include/ck_tile/core/numeric/arithmetic.hpp b/include/ck_tile/core/numeric/arithmetic.hpp index 970ea9ff61..ad45a45e15 100644 --- a/include/ck_tile/core/numeric/arithmetic.hpp +++ b/include/ck_tile/core/numeric/arithmetic.hpp @@ -4,44 +4,36 @@ #pragma once -#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \ - CK_TILE_HOST_DEVICE \ - bool operator==(const type_& x, const type_& y) \ +#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \ + attr_ bool operator==(const type_& x, const type_& y) \ { \ return static_cast(x) == static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - bool operator!=(const type_& x, const type_& y) \ + attr_ bool operator!=(const type_& x, const type_& y) \ { \ return static_cast(x) != static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - bool operator<(const type_& x, const type_& y) \ + attr_ bool operator<(const type_& x, const type_& y) \ { \ return static_cast(x) < static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - bool operator<=(const type_& x, const type_& y) \ + attr_ bool operator<=(const type_& x, const type_& y) \ { \ return static_cast(x) <= static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - bool operator>(const type_& x, const type_& y) \ + attr_ bool operator>(const type_& x, const type_& y) \ { \ return static_cast(x) > static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - bool operator>=(const type_& x, const type_& y) \ + attr_ bool operator>=(const type_& x, const type_& y) \ { \ return static_cast(x) >= static_cast(y); \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator+(const type_& x, const type_& y) \ + attr_ type_ operator+(const type_& x, const type_& y) \ { \ return type_(static_cast(x) + static_cast(y)); \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator-(const type_& x) \ + attr_ type_ operator-(const type_& x) \ { \ constexpr uint32_t bits = sizeof(type_) * 8; \ constexpr uint32_t mask = 1 << (bits - 1); \ @@ -49,66 +41,55 @@ y.data ^= static_cast(mask); \ return y; \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator-(const type_& x, const type_& y) \ + attr_ type_ operator-(const type_& x, const type_& y) \ { \ return type_(static_cast(x) - static_cast(y)); \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator*(const type_& x, const type_& y) \ + attr_ type_ operator*(const type_& x, const type_& y) \ { \ return type_(static_cast(x) * static_cast(y)); \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator/(const type_& x, const type_& y) \ + attr_ type_ operator/(const type_& x, const type_& y) \ { \ return type_(static_cast(x) / static_cast(y)); \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator+=(type_& x, const type_& y) \ + attr_ type_& operator+=(type_& x, const type_& y) \ { \ x = type_(static_cast(x) + static_cast(y)); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator-=(type_& x, const type_& y) \ + attr_ type_& operator-=(type_& x, const type_& y) \ { \ x = type_(static_cast(x) - static_cast(y)); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator*=(type_& x, const type_& y) \ + attr_ type_& operator*=(type_& x, const type_& y) \ { \ x = type_(static_cast(x) * static_cast(y)); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator/=(type_& x, const type_& y) \ + attr_ type_& operator/=(type_& x, const type_& y) \ { \ x = type_(static_cast(x) / static_cast(y)); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator++(type_& x) \ + attr_ type_& operator++(type_& x) \ { \ x = type_(static_cast(x) + 1.f); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_& operator--(type_& x) \ + attr_ type_& operator--(type_& x) \ { \ x = type_(static_cast(x) - 1.f); \ return x; \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator++(type_& x, int) \ + attr_ type_ operator++(type_& x, int) \ { \ type_ y(x); \ x = type_(static_cast(x) + 1.f); \ return y; \ } \ - CK_TILE_HOST_DEVICE \ - type_ operator--(type_& x, int) \ + attr_ type_ operator--(type_& x, int) \ { \ type_ y(x); \ x = type_(static_cast(x) - 1.f); \ diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 6fd433f005..abcc8fdc1a 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -24,9 +24,16 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}); +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant = {}); + CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x); +CK_TILE_HOST_DEVICE +double bf16_to_double_raw(uint16_t x); + // HIP use __hip_bfloat16 as struct struct alignas(2) bfloat16_t { @@ -48,6 +55,10 @@ struct alignas(2) bfloat16_t CK_TILE_HOST_DEVICE explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {} + // construct from double + CK_TILE_HOST_DEVICE + explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {} + // construct from int CK_TILE_HOST_DEVICE explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast(x))) {} @@ -63,6 +74,10 @@ struct alignas(2) bfloat16_t CK_TILE_HOST_DEVICE explicit constexpr operator float() const { return bf16_to_float_raw(data); } + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator double() const { return bf16_to_double_raw(data); } + // cast to int CK_TILE_HOST_DEVICE explicit constexpr operator int() const { return static_cast(bf16_to_float_raw(data)); } @@ -157,6 +172,12 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant) return float_to_bf16_truc_raw(f); } +template +CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant) +{ + return float_to_bf16_raw(static_cast(f), constant{}); +} + CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x) { @@ -168,6 +189,9 @@ float bf16_to_float_raw(uint16_t x) return u.fp32; } +CK_TILE_HOST_DEVICE +double bf16_to_double_raw(uint16_t x) { return static_cast(bf16_to_float_raw(x)); } + template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant) @@ -175,9 +199,19 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant) return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant{})); } +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant) +{ + return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant{})); +} + CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x) { return static_cast(x); } +CK_TILE_HOST_DEVICE +double bf16_to_double(bfloat16_t x) { return static_cast(x); } + template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant = {}) @@ -240,7 +274,7 @@ struct numeric_limits } }; -CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t) +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t) // math CK_TILE_HOST_DEVICE diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 11e971661b..e94c9e7764 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -184,7 +184,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr Y nan_code = 0x80; + constexpr Y nan_code = __builtin_bit_cast(Y, static_cast(0x80)); constexpr uint32_t nan_mask = numeric_utils::nan_mask; // convert to bitwise @@ -215,7 +215,7 @@ CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) // check if x is 0.0 if(x_bitwise == 0) - return 0; + return __builtin_bit_cast(Y, static_cast(0)); // First need to check if it is normal or denorm as there is a difference of implict 1 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift @@ -317,15 +317,18 @@ In this case, the fp16 mantissa should be shift left by 1 */ } else { - return signed_inf; + return __builtin_bit_cast(Y, static_cast(signed_inf)); } } // check if x is 0.0 or -0.0 if(out_exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + return __builtin_bit_cast( + Y, static_cast(negative_zero_nan ? 0 : (sign << (out_exp + out_mant)))); mantissa &= (1 << out_mant) - 1; - return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; + return __builtin_bit_cast(Y, + static_cast((sign << (out_exp + out_mant)) | + (out_exponent << out_mant) | mantissa)); } template @@ -338,9 +341,10 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) // resulting type exponent/mantissa layout constexpr int out_exp = numeric_utils::exp; constexpr int out_mant = numeric_utils::mant; + uint8_t x_raw = __builtin_bit_cast(uint8_t, x); // prepare the codes - constexpr X nan_code = 0x80; + constexpr uint8_t nan_code = 0x80; Y Inf, NegInf, NaN, Neg0; using T_bitwise = typename numeric_utils::bitwise_type; @@ -355,13 +359,13 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) Neg0 = *(reinterpret_cast(&Neg0_bitwise)); // check if x is 0.0 - if(x == 0) + if(x_raw == 0) return static_cast(0); // unpack the input - uint32_t sign = x >> (in_exp + in_mant); - uint32_t mantissa = x & ((1 << in_mant) - 1); - int exponent = (x & 0x7F) >> in_mant; + uint32_t sign = x_raw >> (in_exp + in_mant); + uint32_t mantissa = x_raw & ((1 << in_mant) - 1); + int exponent = (x_raw & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); @@ -369,12 +373,12 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(x == nan_code) + if(x_raw == nan_code) return NaN; } else { - if(x == nan_code) + if(x_raw == nan_code) return Neg0; if(exponent == ((1 << in_exp) - 1)) return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; @@ -382,7 +386,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) if((numeric_utils::mant == 10) && (numeric_utils::mant == 2) && !negative_zero_nan) { - retval = x; + retval = x_raw; retval <<= 8; return *(reinterpret_cast(&retval)); } @@ -700,8 +704,8 @@ struct numeric_limits CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); } }; -CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t) -CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t) +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t) +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t) // math CK_TILE_HOST_DEVICE diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index b22f71c045..5ad9e3aacd 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/limits.hpp" #include @@ -15,9 +16,15 @@ using fp16_hip_t = __half; // most of hip internal function use this type CK_TILE_HOST_DEVICE float fp16_to_float_hip(const fp16_hip_t& x); +CK_TILE_HOST_DEVICE +double fp16_to_double_hip(const fp16_hip_t& x); + CK_TILE_HOST_DEVICE fp16_hip_t float_to_fp16_hip(const float& x); +CK_TILE_HOST_DEVICE +fp16_hip_t double_to_fp16_hip(const double& x); + // HIP use fp16_hip_t as interchangable data type for float16 struct alignas(2) half_t { @@ -46,6 +53,10 @@ struct alignas(2) half_t CK_TILE_HOST_DEVICE explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {} + // construct from double + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {} + // construct from int CK_TILE_HOST_DEVICE explicit constexpr half_t(const int& x) : half_t(static_cast(__int2half_rn(x))) {} @@ -61,6 +72,10 @@ struct alignas(2) half_t CK_TILE_HOST_DEVICE explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); } + // cast to double + CK_TILE_HOST_DEVICE + explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); } + // cast to int CK_TILE_HOST_DEVICE explicit constexpr operator int() const @@ -87,6 +102,9 @@ float fp16_to_float_hip(const fp16_hip_t& x) return static_cast(x); } +CK_TILE_HOST_DEVICE +double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast(fp16_to_float_hip(x)); } + CK_TILE_HOST_DEVICE fp16_hip_t float_to_fp16_hip(const float& x) { @@ -94,12 +112,25 @@ fp16_hip_t float_to_fp16_hip(const float& x) return static_cast(x); } +CK_TILE_HOST_DEVICE +fp16_hip_t double_to_fp16_hip(const double& x) +{ + // return __float2half(x); + return static_cast(x); +} + CK_TILE_HOST_DEVICE float fp16_to_float(const half_t& x) { return static_cast(x); } +CK_TILE_HOST_DEVICE +float fp16_to_double(const half_t& x) { return static_cast(x); } + CK_TILE_HOST_DEVICE half_t float_to_fp16(const float& x) { return half_t{x}; } +CK_TILE_HOST_DEVICE +half_t double_to_fp16(const double& x) { return half_t{x}; } + // limits template struct numeric_limits; @@ -156,94 +187,94 @@ struct numeric_utils }; // arithmetic -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator+(const half_t& x, const half_t& y) { return half_t(__hadd(x.to_fp16(), y.to_fp16())); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator-(const half_t& x, const half_t& y) { return half_t(__hsub(x.to_fp16(), y.to_fp16())); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator*(const half_t& x, const half_t& y) { return half_t(__hmul(x.to_fp16(), y.to_fp16())); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator/(const half_t& x, const half_t& y) { return half_t(__hdiv(x.to_fp16(), y.to_fp16())); } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator+=(half_t& x, const half_t& y) { x = half_t(__hadd(x.to_fp16(), y.to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator-=(half_t& x, const half_t& y) { x = half_t(__hsub(x.to_fp16(), y.to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator*=(half_t& x, const half_t& y) { x = half_t(__hmul(x.to_fp16(), y.to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator/=(half_t& x, const half_t& y) { x = half_t(__hdiv(x.to_fp16(), y.to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator++(half_t& x) { x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t& operator--(half_t& x) { x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); return x; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator++(half_t& x, int) { half_t y(x); @@ -251,7 +282,7 @@ half_t operator++(half_t& x, int) return y; } -CK_TILE_HOST_DEVICE +CK_TILE_DEVICE half_t operator--(half_t& x, int) { half_t y(x); @@ -259,6 +290,8 @@ half_t operator--(half_t& x, int) return y; } +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t) + // math CK_TILE_HOST_DEVICE half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); } diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 9021b30efd..879f9565b8 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -14,8 +14,9 @@ struct constant using value_type = decltype(v); using type = constant; // using injected-class-name static constexpr value_type value = v; - constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } - constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } + CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } }; template diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 9615b979d5..90a8084b85 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include #include +#include namespace ck_tile { @@ -147,8 +148,8 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& return min(max(x, lowerbound), upperbound); } -CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } -CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); } +CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); } +CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); } // greatest common divisor, aka highest common factor CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y) @@ -246,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x) { // TODO: x need to be 1 ~ 0x7fffffff // __builtin_clz will produce unexpected result if x is 0; - return 31 - clz(x); + return 31 - __builtin_clz(x); } CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x) @@ -275,7 +276,7 @@ struct log2e }; template -inline constexpr T log2e_v = log2e::value; +constexpr T log2e_v = log2e::value; // math CK_TILE_HOST_DEVICE @@ -298,16 +299,32 @@ bool isnan(const float& x) return (xx & 0x7fffffff) > 0x7F800000; } +CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }; + +CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }; + CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; +CK_TILE_DEVICE +double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; + CK_TILE_DEVICE float exp(float x) { return __expf(x); }; +CK_TILE_HOST +float exp(float x) { return std::expf(x); } + CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }; +CK_TILE_HOST +float exp2(float x) { return std::exp2f(x); }; + CK_TILE_DEVICE float log(float x) { return __logf(x); }; +CK_TILE_HOST +float log(float x) { return std::logf(x); }; + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index d64f3f4349..81bd55ee86 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -43,11 +43,11 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) return static_cast(type_convert(x)); } -#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ - template <> \ - inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ - { \ - return stype_##_to_##dtype_(x); \ +#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ + template <> \ + CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return stype_##_to_##dtype_(x); \ } CK_TILE_TYPE_CONVERT(float, fp16_t) diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index f20891296b..0449a085ab 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -63,12 +63,12 @@ using fp32x32_t = float __attribute__((ext_vector_type(32))); using fp32x64_t = float __attribute__((ext_vector_type(64))); // fp16 -using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2))); -using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4))); -using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8))); -using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16))); -using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32))); -using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64))); +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using fp16x4_t = _Float16 __attribute__((ext_vector_type(4))); +using fp16x8_t = _Float16 __attribute__((ext_vector_type(8))); +using fp16x16_t = _Float16 __attribute__((ext_vector_type(16))); +using fp16x32_t = _Float16 __attribute__((ext_vector_type(32))); +using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); // bfp16 using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); @@ -94,6 +94,14 @@ using int16x16_t = int16_t __attribute__((ext_vector_type(16))); using int16x32_t = int16_t __attribute__((ext_vector_type(32))); using int16x64_t = int16_t __attribute__((ext_vector_type(64))); +// u16 +using uint16x2_t = uint16_t __attribute__((ext_vector_type(2))); +using uint16x4_t = uint16_t __attribute__((ext_vector_type(4))); +using uint16x8_t = uint16_t __attribute__((ext_vector_type(8))); +using uint16x16_t = uint16_t __attribute__((ext_vector_type(16))); +using uint16x32_t = uint16_t __attribute__((ext_vector_type(32))); +using uint16x64_t = uint16_t __attribute__((ext_vector_type(64))); + // i8 using int8x2_t = int8_t __attribute((ext_vector_type(2))); using int8x4_t = int8_t __attribute((ext_vector_type(4))); diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index edf3e6eebb..e1bd9c4d19 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -79,8 +79,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT using InVec = array; using OutVec = array; - using InVecType = typename InVec::type; - using OutVecType = typename OutVec::type; + // using InVec = typename InVec::type; + // using OutVec = typename OutVec::type; // SFC constexpr auto scalars_per_access_arr = generate_array( @@ -115,9 +115,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT number{}); constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); - in_vectors(i).template get_as()(I0) = - in_tensor.get_thread_buffer().template get_as(number{}); + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer().template get_as( + number{}); }); // transpose @@ -133,10 +135,11 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); - out_tensor.get_thread_buffer().template set_as( - number{}, - out_vectors[i].template get_as()[I0]); + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); }); }); } diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 2a3ecd8f70..8995dfd403 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ \ - constexpr auto trans = [&encoded_transforms]() { \ + constexpr auto trans = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) constexpr { \ constexpr auto name = encoded_transforms[i].template at<0>(); \ @@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ \ - constexpr auto trans = [&encoded_transforms]() { \ + constexpr auto trans = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) constexpr { \ constexpr auto name = encoded_transforms[i].template at<0>(); \ @@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. number{}); \ }(); \ \ - constexpr auto low_dim_idss = [&encoded_transforms]() { \ + constexpr auto low_dim_idss = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) { \ constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ @@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. number()); \ }(); \ \ - constexpr auto up_dim_idss = [&encoded_transforms] { \ + constexpr auto up_dim_idss = [&encoded_transforms] { \ return generate_tuple( \ [&encoded_transforms](auto i) { \ constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index aa9cf108c5..0c3e04f315 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -90,7 +90,7 @@ struct tensor_descriptor : public tensor_adaptor{}; + constexpr auto h_minor_lengths = + HsLengthss{}.get(idim_x); // std::tuple_element_t{}; // constexpr auto h_minor_lengths = impl::getv(HsLengthss{}); constexpr index_t ndim_h_minor = h_minor_lengths.size(); @@ -532,7 +533,7 @@ struct reverse_slice_sequence_impl, using old_scan = reverse_slice_sequence_impl, sequence, sequence, SliceSize>; - static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value; + static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value; static constexpr auto slice_length = std::conditional_t, number>::value; @@ -546,7 +547,7 @@ struct reverse_slice_sequence_impl, // the first idx that sliced length not equal to original length static constexpr index_t _flag = - slice_length != x && remaining_slice_sizes{}.Front().value == 1; + slice_length != x && remaining_slice_sizes{}.front().value == 1; static constexpr index_t _split_flag = std::conditional_t, number<0>>::value; static constexpr index_t _split_idx = std::conditional_t<_split_flag, number, number<0>>::value; @@ -570,7 +571,7 @@ struct reverse_slice_sequence_impl, sequence, sequence, Slice // the first idx that sliced length not equal to original length static constexpr index_t _flag = - slice_length != x && remaining_slice_sizes{}.Front().value == 1; + slice_length != x && remaining_slice_sizes{}.front().value == 1; static constexpr index_t split_flag = std::conditional_t, number<0>>::value; static constexpr index_t split_idx = std::conditional_t, number<0>>::value; @@ -613,7 +614,7 @@ constexpr auto reverse_slice_sequence(Seq, Mask, typename arithmetic_sequence_gen<0, Seq::size(), 1>::type, SliceSize>; - static_assert(sliced_type::remaining_slice_sizes::Front().value == 1, + static_assert(sliced_type::remaining_slice_sizes::front().value == 1, "can not evenly divide this sequence, please check"); return make_tuple(typename sliced_type::dim_lengths{}, typename sliced_type::dim_slices{}, diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 76db527780..be4c67dbce 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/sequence.hpp" diff --git a/include/ck_tile/core/utility/to_sequence.hpp b/include/ck_tile/core/utility/to_sequence.hpp index 4db6cfd4a0..2276ab68b7 100644 --- a/include/ck_tile/core/utility/to_sequence.hpp +++ b/include/ck_tile/core/utility/to_sequence.hpp @@ -7,14 +7,14 @@ #if 1 // clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode -#define TO_SEQUENCE(a, n) \ - _Pragma("clang diagnostic push") \ - _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \ - [a](ck_tile::sequence) \ - { \ - return ck_tile::sequence{})...>{}; \ - } \ - (ck_tile::make_index_sequence{}); \ +#define TO_SEQUENCE(a, n) \ + _Pragma("clang diagnostic push") _Pragma( \ + "clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ + ck_tile::sequence) \ + { \ + return ck_tile::sequence{})...>{}; \ + } \ + (ck_tile::make_index_sequence{}); \ _Pragma("clang diagnostic pop") #else diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 8b8b01a2ac..f5dffda863 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -22,27 +22,6 @@ using remove_cvref_t = remove_cv_t>; template using remove_pointer_t = typename std::remove_pointer::type; -namespace impl { -template -struct is_static_impl -{ - static constexpr bool value = std::is_arithmetic::v ? false : T::is_static(); -}; -} // namespace impl - -template -using is_static = impl::is_static_impl>; - -template -inline constexpr bool is_static_v = is_static::value; - -// TODO: deprecate this -template -using is_known_at_compile_time = is_static; -// TODO: if evaluating a rvalue, e.g. a const integer -// , this helper will also return false, which is not good(?) -// do we need something like is_constexpr()? - namespace detail { template class Op, class... Args> struct detector @@ -69,6 +48,36 @@ struct nonesuch template