// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck/utility/common_header.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/index_expression.hpp" namespace ck { /** * @brief Invokes multiple functors based on an index parameter * @tparam Funcs Parameter pack of functor types * @details Stores a tuple of functors and provides an operator() that invokes all of them * with the same index parameter. Uses static_for to iterate through the functors. */ template struct FunctorInvoker { ck::Tuple funcs; __host__ __device__ constexpr FunctorInvoker(Funcs... fs) : funcs(ck::forward(fs)...) {} /** * @brief Invokes all functors with the given index * @tparam I The index to pass to each functor * @param i Number wrapper containing the index value */ template __host__ __device__ constexpr void operator()(ck::Number i) const { invoke(i, std::index_sequence_for{}); } private: template __host__ __device__ constexpr void invoke(ck::Number i, std::index_sequence) const { (funcs[ck::Number(Is)>{}](i), ...); } }; // required for CTAD to work with __host__ __device__ qualifiers template __host__ __device__ constexpr auto MakeFunctorInvoker(Fs&&... fs) { return FunctorInvoker{ck::forward(fs)...}; } /** * @brief Helper struct for evaluating compile-time index expressions * @tparam T The expression type to evaluate * @tparam ik The index variable value * @details Provides a value member that evaluates the index expression T using * the index_expression::eval_v */ template struct IndexEval; template struct IndexEval : IndexEval { }; template struct IndexEval, ik> { static constexpr index_t value = v; }; template struct IndexEval { static constexpr index_t value = ik; }; template struct IndexEval, ik> { static constexpr index_t value = IndexEval::value + IndexEval::value; }; template struct IndexEval, ik> { static constexpr index_t value = IndexEval::value * IndexEval::value; }; template struct IndexEval, ik> { static constexpr index_t divisor = IndexEval::value; static_assert(divisor != 0, "ck::index_expression::Div: division by zero in compile-time index expression"); static constexpr index_t value = IndexEval::value / divisor; }; template struct IndexEval, ik> { static constexpr index_t divisor = IndexEval::value; static_assert(divisor != 0, "ck::index_expression::Mod: modulo by zero in compile-time index expression"); static constexpr index_t value = IndexEval::value % divisor; }; /** * @brief Loads thread elements from buffer to vector using compile-time index expressions * @tparam ThreadVec The vector type to load into * @tparam ThreadBuf The buffer type to load from * @tparam ThreadDesc The descriptor for thread memory layout * @tparam ComputeType The computation type for the result * @tparam IdxExpr Parameter pack of compile-time index expressions * @details Uses index expressions to compute offsets in ThreadBuf and loads the values * into the ThreadVec. The operator() accepts a compile-time index and evaluates all * index expressions for that particular index value. * * Example: * @code * // Load from buffer using index expressions Ik (the loop index) and Number<5> * using Loader = thread_buf_to_vec_loader>; * Loader loader{thread_vec, thread_buf}; * loader(Number<3>{}); // Loads at offset computed by evaluating expressions with ik=3 * @endcode */ template struct thread_buf_to_vec_loader { ThreadVec& thread_vec; ThreadBuf& thread_buf; __host__ __device__ constexpr thread_buf_to_vec_loader(ThreadVec& tv, ThreadBuf& tb) : thread_vec(tv), thread_buf(tb) { } /** * @brief Loads a single element from buffer to vector for the given index * @tparam ik The index value for which to evaluate the index expressions */ template __host__ __device__ constexpr void operator()(Number) const { // TODO c++20: ThreadDesc could be an auto parameter, but clang doesn't support auto // non-type template parameters yet constexpr auto thread_desc = ThreadDesc{}; constexpr auto idx_tuple = ck::make_tuple(Number::value>{}...); constexpr auto offset = thread_desc.CalculateOffset(idx_tuple); auto& target = thread_vec.template AsType()(Number{}); target = thread_buf[Number{}]; } }; } // namespace ck