mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)
* interface for GEMM and GEMM+add+add+fastgelu * rename namespace * instance factory * fix build * fix build; add GEMM client example * clean
This commit is contained in:
@@ -22,7 +22,7 @@ using INT8 = int8_t;
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
|
||||
using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
|
||||
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
|
||||
Reference in New Issue
Block a user