mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Codegen hipRTC compilation (#1579)
* updating codegen build for MIOpen access: adding .cmake for codegen component * updating CMake * adding in header guards for some headers due to issues with hiprtc compilation in MIOpen * some more header guards * putting env file in header guard * cleaning up some includes * updated types file for hiprtc purposes * fixed types file: bit-wise/memcpy issue * updating multiple utility files to deal with standard header inclusion for hiprtc * added some more header guards in the utility files, replacing some standard header functionality * added some more header guards * fixing some conflicts in utility files, another round of header guards * fixing errors in data type file * resolved conflict errors in a few utility files * added header guards/replicated functionality in device files * resolved issues with standard headers in device files: device_base and device_grouped_conv_fwd_multiple_abd * resolved issues with standard headers in device files: device_base.hpp, device_grouped_conv_fwd_multiple_abd.hpp, device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp * added header guards for gridwise gemm files: gridwise_gemm_multiple_abd_xdl_cshuffle.hpp and gridwise_gemm_multiple_d_xdl_cshuffle.hpp * fixed issue with numerics header, removed from transform_conv_fwd_to_gemm and added to device_column_to_image_impl, device_grouped_conv_fwd_multiple_abd_xdl_cshuffle, device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3, device_image_to_column_impl * replaced standard header usage and added header guards in block to ctile map and gridwise_gemm_pipeline_selector * resolved errors in device_gemm_xdl_splitk_c_shuffle files in regards to replacement of standard headers in previous commit * added replicated functionality for standard header methods in utility files * replaced standard header functionality in threadwise tensor slice transfer files and added header guards in element_wise_operation.hpp * temp fix for namespace error in MIOpen * remove standard header usage in codegen device op * removed standard header usage in elementwise files, resolved namespace errors * formatting fix * changed codegen argument to ON for testing * temporarily removing codegen compiler flag for testing purposes * added codegen flag again, set default to ON * set codegen flag default back to OFF * replaced enable_if_t standard header usage in data_type.hpp * added some debug prints to pinpoint issues in MIOpen * added print outs to debug in MIOpen * removed debug print outs from device op * resolved stdexcept include error * formatting fix * adding includes to new fp8 file to resolve ck::enable_if_t errors * made changes to amd_wave_read_first_lane * updated functionality in type utility file * fixed end of file issue * resovled errors in type utility file, added functionality to array utility file * fixed standard header usage replication in data_type file, resolves error with failing examples on navi3x * formatting fix * replaced standard header usage in amd_ck_fp8 file * added include to random_gen file * removed and replicated standard header usage from data_type and type_convert files for fp8 changes * replicated standard unsigned integer types in random_gen * resolved comments from review: put calls to reinterpret_cast for size_t in header guards * updated/added copyright headers * removed duplicate header * fixed typo in header guard * updated copyright headers --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1,22 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
{
|
||||
using value_t = std::false_type;
|
||||
using value_t = integral_constant<bool, false>;
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <class Default, template <class...> class Op, class... Args>
|
||||
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
|
||||
{
|
||||
using value_t = std::true_type;
|
||||
using value_t = integral_constant<bool, true>;
|
||||
using type = Op<Args...>;
|
||||
};
|
||||
} // namespace detail
|
||||
@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
template <typename T>
|
||||
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
|
||||
using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
|
||||
using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
|
||||
using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user