mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
add c-style pointer cast
This commit is contained in:
@@ -3,6 +3,9 @@
|
||||
|
||||
#include "config.hpp"
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpaceEnum_t
|
||||
@@ -17,15 +20,24 @@ enum AddressSpaceEnum_t
|
||||
template <typename T>
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
|
||||
{
|
||||
return (T*)p;
|
||||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||
// only old style cast seems be able to be compiled
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic push
|
||||
return (T*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
return (T CONSTANT*)p;
|
||||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||
// only old style cast seems be able to be compiled
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic push
|
||||
return (T CONSTANT*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
20
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
20
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef CK_C_STYLE_POINTER_CAST_HPP
|
||||
#define CK_C_STYLE_POINTER_CAST_HPP
|
||||
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename PY,
|
||||
typename PX,
|
||||
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
|
||||
__host__ __device__ PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic push
|
||||
return (PY)p_x; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "type.hpp"
|
||||
#include "magic_division.hpp"
|
||||
#include "utility.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "amd_address_space.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "static_buffer.hpp"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef CK_DYNAMIC_BUFFER_HPP
|
||||
#define CK_DYNAMIC_BUFFER_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
struct DynamicBuffer
|
||||
@@ -44,20 +45,20 @@ struct DynamicBuffer
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
p_data_, i, is_valid_offset, element_space_size_);
|
||||
#else
|
||||
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
|
||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
|
||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,17 +79,17 @@ struct DynamicBuffer
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
x, p_data_, i, is_valid_offset, element_space_size_);
|
||||
#else
|
||||
if(is_valid_offset)
|
||||
{
|
||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -97,7 +98,7 @@ struct DynamicBuffer
|
||||
if(is_valid_offset)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
#else
|
||||
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
|
||||
// inefficient
|
||||
@@ -128,24 +129,24 @@ struct DynamicBuffer
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int8_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int8_t*>(&x);
|
||||
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int8_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int16_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int16_t*>(&x);
|
||||
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int16_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int32_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int32_t*>(&x);
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x4_t>::value &&
|
||||
@@ -153,8 +154,8 @@ struct DynamicBuffer
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int32_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int32_t*>(&x);
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x8_t>::value &&
|
||||
@@ -162,8 +163,8 @@ struct DynamicBuffer
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int32x2_t*>(&x);
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x16_t>::value &&
|
||||
@@ -171,13 +172,13 @@ struct DynamicBuffer
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*reinterpret_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*reinterpret_cast<const int32x4_t*>(&x);
|
||||
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -186,7 +187,7 @@ struct DynamicBuffer
|
||||
{
|
||||
if(is_valid_offset)
|
||||
{
|
||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,10 +22,7 @@ template <typename T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
|
||||
template <typename T>
|
||||
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
|
||||
{
|
||||
return static_cast<typename std::remove_reference<T>::type&&>(t);
|
||||
}
|
||||
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||
|
||||
template <typename T>
|
||||
struct is_known_at_compile_time;
|
||||
|
||||
@@ -290,9 +290,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user