Change all device operations to use add_instance_library (#338)

* Change all device operations to use add_instance_library to avoid duplicated CMake configuration.

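* Update DeviceMem: const-qualify its member functions and move the SetValue template definition out of the class body.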

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
cloudhan
2022-08-14 01:17:58 +08:00
committed by GitHub
parent 0bd6b842b9
commit fb1cbf025b
31 changed files with 190 additions and 358 deletions

View File

@@ -18,23 +18,26 @@ struct DeviceMem
{
DeviceMem() = delete;
DeviceMem(std::size_t mem_size);
void* GetDeviceBuffer();
std::size_t GetBufferSize();
void ToDevice(const void* p);
void FromDevice(void* p);
void SetZero();
void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const;
void ToDevice(const void* p) const;
void FromDevice(void* p) const;
void SetZero() const;
template <typename T>
void SetValue(T x)
{
if(mMemSize % sizeof(T) != 0)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
void SetValue(T x) const;
~DeviceMem();
void* mpDeviceBuf;
std::size_t mMemSize;
};
/// Launches the set_buffer_value kernel over the device buffer viewed as an
/// array of T, passing `x` as the value argument.
///
/// @tparam T element type the buffer is reinterpreted as
/// @param x value forwarded to the set_buffer_value kernel
/// @throws std::runtime_error when the buffer size in bytes is not a whole
///         multiple of sizeof(T)
template <typename T>
void DeviceMem::SetValue(T x) const
{
    const bool covers_whole_buffer = (mMemSize % sizeof(T) == 0);

    if(!covers_whole_buffer)
    {
        throw std::runtime_error("wrong! not entire DeviceMem will be set");
    }

    const auto element_count = mMemSize / sizeof(T);
    auto* typed_buffer       = static_cast<T*>(mpDeviceBuf);

    // Launch configuration: 1 block of 1024 threads.
    set_buffer_value<T><<<1, 1024>>>(typed_buffer, x, element_count);
}

View File

@@ -3,6 +3,7 @@ function(add_instance_library INSTANCE_NAME)
# Body of add_instance_library(INSTANCE_NAME <sources...>): builds one OBJECT
# library named INSTANCE_NAME from every argument after the name (${ARGN}).
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
# No compile features are listed after PUBLIC in this call.
target_compile_features(${INSTANCE_NAME} PUBLIC)
# Compile the objects as position-independent code.
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# clang_tidy_check is defined outside this view; presumably it registers the
# target for clang-tidy analysis -- confirm against its definition.
clang_tidy_check(${INSTANCE_NAME})
endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(gemm)

View File

@@ -1,26 +1,18 @@
#device_batched_gemm_instance
set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp;
add_instance_library(device_batched_gemm_instance
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
)
add_library(device_batched_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
# target_compile_features(device_batched_gemm_instance PUBLIC)
set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
# install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_batched_gemm_instance)

View File

@@ -1,8 +1,3 @@
set(DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
add_instance_library(device_batched_gemm_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE})
target_compile_features(device_batched_gemm_gemm_instance PUBLIC)
set_target_properties(device_batched_gemm_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_batched_gemm_gemm_instance)

View File

@@ -1,12 +1,7 @@
set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE
add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
)
add_instance_library(device_batched_gemm_reduce_instance OBJECT ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE})
target_compile_features(device_batched_gemm_reduce_instance PUBLIC)
set_target_properties(device_batched_gemm_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_batched_gemm_reduce_instance)

View File

@@ -1,8 +1,4 @@
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_INSTANCE_SOURCE
add_instance_library(device_batched_gemm_softmax_gemm_instance
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
add_instance_library(device_batched_gemm_softmax_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_INSTANCE_SOURCE})
target_compile_features(device_batched_gemm_softmax_gemm_instance PUBLIC)
set_target_properties(device_batched_gemm_softmax_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_batched_gemm_softmax_gemm_instance)

View File

@@ -1,12 +1,7 @@
# device_contraction_bilinear_instance
set(DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE
add_instance_library(device_contraction_bilinear_instance
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
)
add_library(device_contraction_bilinear_instance OBJECT ${DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE})
set_target_properties(device_contraction_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_contraction_bilinear_instance)

View File

@@ -1,12 +1,7 @@
# device_contraction_scale_instance
set(DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE
add_instance_library(device_contraction_scale_instance
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp
)
add_library(device_contraction_scale_instance OBJECT ${DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE})
set_target_properties(device_contraction_scale_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_contraction_scale_instance)

View File

@@ -1,14 +1,6 @@
# device_conv1d_bwd_data_instance
set(DEVICE_CONV1D_BWD_DATA_INSTANCE_SOURCE
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp;
add_instance_library(device_conv1d_bwd_data_instance
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp
)
add_library(device_conv1d_bwd_data_instance OBJECT ${DEVICE_CONV1D_BWD_DATA_INSTANCE_SOURCE})
target_compile_features(device_conv1d_bwd_data_instance PUBLIC)
set_target_properties(device_conv1d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_conv1d_bwd_data_instance)
clang_tidy_check(device_conv1d_bwd_data_instance)

View File

@@ -1,13 +1,5 @@
#device_conv1d_bwd_weight_instance
set(DEVICE_CONV1D_BWD_WEIGHT_INSTANCE_SOURCE
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp;
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp;
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp;
add_instance_library(device_conv1d_bwd_weight_instance
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp
device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp
)
add_library(device_conv1d_bwd_weight_instance OBJECT ${DEVICE_CONV1D_BWD_WEIGHT_INSTANCE_SOURCE})
target_compile_features(device_conv1d_bwd_weight_instance PUBLIC)
set_target_properties(device_conv1d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_conv1d_bwd_weight_instance)
clang_tidy_check(device_conv1d_bwd_weight_instance)

View File

@@ -1,12 +1,6 @@
# device_conv2d_bwd_data_instance
set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
add_instance_library(device_conv2d_bwd_data_instance
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
)
add_library(device_conv2d_bwd_data_instance OBJECT ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE})
set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_bwd_data_instance)

View File

@@ -1,13 +1,6 @@
#device_conv2d_bwd_weight_instance
set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
add_instance_library(device_conv2d_bwd_weight_instance
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
)
add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE})
target_compile_features(device_conv2d_bwd_weight_instance PUBLIC)
set_target_properties(device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_conv2d_bwd_weight_instance)
clang_tidy_check(device_conv2d_bwd_weight_instance)

View File

@@ -1,12 +1,7 @@
# device_conv2d_fwd_instance
set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
add_instance_library(device_conv2d_fwd_instance
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
)
add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_instance)

View File

@@ -1,8 +1,3 @@
# device_conv2d_fwd_bias_relu_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
add_instance_library(device_conv2d_fwd_bias_relu_instance
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
)
add_library(device_conv2d_fwd_bias_relu_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_bias_relu_instance)

View File

@@ -1,8 +1,4 @@
# device_conv2d_fwd_bias_relu_add_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
add_instance_library(device_conv2d_fwd_bias_relu_add_instance
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
)
add_library(device_conv2d_fwd_bias_relu_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance)

View File

@@ -1,14 +1,6 @@
# device_conv3d_bwd_data_instance
set(DEVICE_CONV3D_BWD_DATA_INSTANCE_SOURCE
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
add_instance_library(device_conv3d_bwd_data_instance
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp
)
add_library(device_conv3d_bwd_data_instance OBJECT ${DEVICE_CONV3D_BWD_DATA_INSTANCE_SOURCE})
target_compile_features(device_conv3d_bwd_data_instance PUBLIC)
set_target_properties(device_conv3d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_conv3d_bwd_data_instance)
clang_tidy_check(device_conv3d_bwd_data_instance)

View File

@@ -1,13 +1,5 @@
#device_conv3d_bwd_weight_instance
set(DEVICE_CONV3D_BWD_WEIGHT_INSTANCE_SOURCE
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp;
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp;
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
add_instance_library(device_conv3d_bwd_weight_instance
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp
)
add_library(device_conv3d_bwd_weight_instance OBJECT ${DEVICE_CONV3D_BWD_WEIGHT_INSTANCE_SOURCE})
target_compile_features(device_conv3d_bwd_weight_instance PUBLIC)
set_target_properties(device_conv3d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_conv3d_bwd_weight_instance)
clang_tidy_check(device_conv3d_bwd_weight_instance)

View File

@@ -1,10 +1,3 @@
set(DEVICE_ELEMENTWISE_INSTANCE_SOURCE
add_instance_library(device_elementwise_instance
device_normalize_instance.cpp
)
add_instance_library(device_elementwise_instance ${DEVICE_ELEMENTWISE_INSTANCE_SOURCE})
target_compile_features(device_elementwise_instance PUBLIC)
set_target_properties(device_elementwise_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_elementwise_instance)

View File

@@ -1,48 +1,43 @@
set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp;
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp;
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp;
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp;
add_instance_library(device_gemm_instance
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
)
add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE})
target_compile_features(device_gemm_instance PUBLIC)
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)

View File

@@ -1,14 +1,6 @@
# device_gemm_add_add_fastgelu_instance
set(DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp;
add_instance_library(device_gemm_add_add_fastgelu_instance
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)
add_library(device_gemm_add_add_fastgelu_instance OBJECT ${DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE})
target_compile_features(device_gemm_add_add_fastgelu_instance PUBLIC)
set_target_properties(device_gemm_add_add_fastgelu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_gemm_add_add_fastgelu_instance)

View File

@@ -1,13 +1,6 @@
set(DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE
add_instance_library(device_gemm_bias_add_reduce_instance
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)
add_library(device_gemm_bias_add_reduce_instance OBJECT ${DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE})
target_compile_features(device_gemm_bias_add_reduce_instance PUBLIC)
set_target_properties(device_gemm_bias_add_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_gemm_bias_add_reduce_instance)

View File

@@ -1,12 +1,6 @@
# device_gemm_bilinear_instance
set(DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp;
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp;
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp;
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp;
add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
)
add_library(device_gemm_bilinear_instance OBJECT ${DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE})
set_target_properties(device_gemm_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_gemm_bilinear_instance)

View File

@@ -1,10 +1,6 @@
set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
add_instance_library(device_gemm_reduce_instance
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)
add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE})
rocm_install(TARGETS device_gemm_reduce_instance)
clang_tidy_check(device_gemm_reduce_instance)

View File

@@ -1,15 +1,10 @@
set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
add_instance_library(device_gemm_splitk_instance
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
)
add_library(device_gemm_splitk_instance OBJECT ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE})
target_compile_features(device_gemm_splitk_instance PUBLIC)
set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)

View File

@@ -1,12 +1,6 @@
# device_grouped_conv1d_fwd_instance
set(DEVICE_GROUPED_CONV1D_FWD_INSTANCE_SOURCE
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp;
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp;
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp;
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp;
add_instance_library(device_grouped_conv1d_fwd_instance
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp
)
add_library(device_grouped_conv1d_fwd_instance OBJECT ${DEVICE_GROUPED_CONV1D_FWD_INSTANCE_SOURCE})
set_target_properties(device_grouped_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_grouped_conv1d_fwd_instance)

View File

@@ -1,15 +1,9 @@
# device_grouped_conv2d_fwd_instance
set(DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE
add_instance_library(device_grouped_conv2d_fwd_instance
# GNHWC, GKYXC, GNHWK
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp;
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp;
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp;
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp;
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp
# NHWGC, GKYXC, NHWGK
device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp;
device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
)
add_library(device_grouped_conv2d_fwd_instance OBJECT ${DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE})
set_target_properties(device_grouped_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_grouped_conv2d_fwd_instance)

View File

@@ -1,12 +1,6 @@
# device_grouped_conv3d_fwd_instance
set(DEVICE_GROUPED_CONV3D_FWD_INSTANCE_SOURCE
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp;
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp;
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp;
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp;
# Register the grouped conv3d forward instances through the shared
# add_instance_library helper, like every other instance directory. A plain
# add_library() call here would not request an OBJECT library and would skip
# the POSITION_INDEPENDENT_CODE property and the clang_tidy_check call that
# the helper applies.
add_instance_library(device_grouped_conv3d_fwd_instance
    device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
    device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
    device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
    device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp
)
add_library(device_grouped_conv3d_fwd_instance OBJECT ${DEVICE_GROUPED_CONV3D_FWD_INSTANCE_SOURCE})
set_target_properties(device_grouped_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_grouped_conv3d_fwd_instance)

View File

@@ -1,15 +1,6 @@
# device_grouped_gemm_instance
set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
)
add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE})
target_compile_features(device_grouped_gemm_instance PUBLIC)
set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_grouped_gemm_instance)
clang_tidy_check(device_grouped_gemm_instance)

View File

@@ -1,12 +1,6 @@
# device_normalization_instance
set(DEVICE_NORMALIZATION_INSTANCE_SOURCE
add_instance_library(device_normalization_instance
device_layernorm_f16_instance.cpp
device_layernorm_f32_instance.cpp
device_softmax_f32_f32_instance.cpp
device_softmax_f16_f16_instance.cpp
)
add_library(device_normalization_instance OBJECT ${DEVICE_NORMALIZATION_INSTANCE_SOURCE})
set_target_properties(device_normalization_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_normalization_instance)

View File

@@ -1,29 +1,23 @@
# device_reduce_instance
set(DEVICE_REDUCE_INSTANCE_SOURCE
device_reduce_instance_blockwise_f16_f16_f16.cpp;
device_reduce_instance_blockwise_f16_f32_f16.cpp;
device_reduce_instance_blockwise_f32_f32_f32.cpp;
device_reduce_instance_blockwise_f32_f64_f32.cpp;
device_reduce_instance_blockwise_f64_f64_f64.cpp;
device_reduce_instance_blockwise_i8_i32_i8.cpp;
device_reduce_instance_blockwise_i8_i8_i8.cpp;
device_reduce_instance_blockwise_b16_f32_b16.cpp;
device_reduce_instance_threadwise_f16_f16_f16.cpp;
device_reduce_instance_threadwise_f16_f32_f16.cpp;
device_reduce_instance_threadwise_f32_f32_f32.cpp;
device_reduce_instance_threadwise_f32_f64_f32.cpp;
device_reduce_instance_threadwise_f64_f64_f64.cpp;
device_reduce_instance_threadwise_i8_i32_i8.cpp;
device_reduce_instance_threadwise_i8_i8_i8.cpp;
device_reduce_instance_threadwise_b16_f32_b16.cpp;
device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp;
device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp;
add_instance_library(device_reduce_instance
device_reduce_instance_blockwise_f16_f16_f16.cpp
device_reduce_instance_blockwise_f16_f32_f16.cpp
device_reduce_instance_blockwise_f32_f32_f32.cpp
device_reduce_instance_blockwise_f32_f64_f32.cpp
device_reduce_instance_blockwise_f64_f64_f64.cpp
device_reduce_instance_blockwise_i8_i32_i8.cpp
device_reduce_instance_blockwise_i8_i8_i8.cpp
device_reduce_instance_blockwise_b16_f32_b16.cpp
device_reduce_instance_threadwise_f16_f16_f16.cpp
device_reduce_instance_threadwise_f16_f32_f16.cpp
device_reduce_instance_threadwise_f32_f32_f32.cpp
device_reduce_instance_threadwise_f32_f64_f32.cpp
device_reduce_instance_threadwise_f64_f64_f64.cpp
device_reduce_instance_threadwise_i8_i32_i8.cpp
device_reduce_instance_threadwise_i8_i8_i8.cpp
device_reduce_instance_threadwise_b16_f32_b16.cpp
device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp
device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp
device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp
device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp
device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp
)
add_library(device_reduce_instance OBJECT ${DEVICE_REDUCE_INSTANCE_SOURCE})
set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_reduce_instance)

View File

@@ -10,20 +10,20 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
/// Returns the pointer stored in mpDeviceBuf (the buffer obtained from hipMalloc
/// in the constructor).
void* DeviceMem::GetDeviceBuffer() const
{
    return mpDeviceBuf;
}
std::size_t DeviceMem::GetBufferSize() { return mMemSize; }
/// Returns the buffer size in bytes, as passed to the constructor.
std::size_t DeviceMem::GetBufferSize() const
{
    return mMemSize;
}
void DeviceMem::ToDevice(const void* p)
/// Copies mMemSize bytes from host memory at `p` into the device buffer.
///
/// @param p host source; mMemSize bytes are read from it
void DeviceMem::ToDevice(const void* p) const
{
    // hipMemcpy takes its source as `const void*`, so the pointer can be
    // passed through unchanged; no const_cast is required.
    hip_check_error(hipMemcpy(mpDeviceBuf, p, mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p)
/// Copies mMemSize bytes from the device buffer into host memory at `p`.
///
/// @param p host destination; mMemSize bytes are written to it
void DeviceMem::FromDevice(void* p) const
{
    const auto copy_status = hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost);
    hip_check_error(copy_status);
}
void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
/// Sets all mMemSize bytes of the device buffer to zero via hipMemset.
void DeviceMem::SetZero() const
{
    const auto memset_status = hipMemset(mpDeviceBuf, 0, mMemSize);
    hip_check_error(memset_status);
}
/// Releases the device buffer with hipFree.
// NOTE(review): hip_check_error is defined outside this view; if it can throw,
// this destructor may throw -- confirm.
DeviceMem::~DeviceMem()
{
    const auto free_status = hipFree(mpDeviceBuf);
    hip_check_error(free_status);
}