// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include #include #include #include #include #include "ck/ck.hpp" namespace ck { namespace host_common { template static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) { std::ofstream outFile(fileName, std::ios::binary); if(outFile) { outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); outFile.close(); std::cout << "Write output to file " << fileName << std::endl; } else { std::cout << "Could not open file " << fileName << " for writing" << std::endl; } }; template static inline T getSingleValueFromString(const std::string& valueStr) { std::istringstream iss(valueStr); T val; iss >> val; return (val); }; template static inline std::vector getTypeValuesFromString(const char* cstr_values) { std::string valuesStr(cstr_values); std::vector values; std::size_t pos = 0; std::size_t new_pos; new_pos = valuesStr.find(',', pos); while(new_pos != std::string::npos) { const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); T val = getSingleValueFromString(sliceStr); values.push_back(val); pos = new_pos + 1; new_pos = valuesStr.find(',', pos); }; std::string sliceStr = valuesStr.substr(pos); T val = getSingleValueFromString(sliceStr); values.push_back(val); return (values); } template static inline std::vector> get_index_set(const std::array& dim_lengths) { static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); if constexpr(NDim == 1) { std::vector> index_set; for(int i = 0; i < dim_lengths[0]; i++) { std::array index{i}; index_set.push_back(index); }; return index_set; } else { std::vector> index_set; std::array partial_dim_lengths; std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin()); std::vector> partial_index_set; partial_index_set = get_index_set(partial_dim_lengths); for(index_t i = 0; i < dim_lengths[0]; i++) for(const auto& partial_index : partial_index_set) { std::array index; index[0] = i; std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1); index_set.push_back(index); }; return index_set; }; }; template static inline size_t get_offset_from_index(const std::array& strides, const std::array& index) { size_t offset = 0; for(int i = 0; i < NDim; i++) offset += index[i] * strides[i]; return (offset); }; } // namespace host_common } // namespace ck