mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)
* interface for GEMM and GEMM+add+add+fastgelu
* rename namespace
* instance factory
* fix build
* fix build; add GEMM client example
* clean
[ROCm/composable_kernel commit: 0dcb3496cf]
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
@@ -116,19 +116,21 @@ bool profile_batched_gemm_impl(int do_verification,
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());
|
||||
|
||||
// add device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
|
||||
get_device_batched_gemm_instances<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>();
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
|
||||
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_weight_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvBwdWeightNoOpPtr =
|
||||
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdWeightNoOpPtr>&);
|
||||
|
||||
} // namespace device_conv2d_bwd_weight_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_bias_activation_add_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvFwdBiasReluAddPtr =
|
||||
DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvFwdBiasReluAddPtr>&);
|
||||
|
||||
} // namespace device_conv2d_fwd_bias_activation_add_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_bias_activation_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvFwdBiasReluPtr =
|
||||
DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvFwdBiasReluPtr>&);
|
||||
|
||||
} // namespace device_conv2d_fwd_bias_activation_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ using INT8 = int8_t;
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_bwd_data_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
|
||||
using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
|
||||
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
|
||||
@@ -10,13 +10,12 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
#include "ck/library/host_tensor/host_conv.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -30,9 +29,7 @@ template <typename ADataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
typename DELayout> // assume Ds and E have same layout
|
||||
bool profile_gemm_add_add_fastgelu_impl(int do_verification,
|
||||
int init_method,
|
||||
bool /*do_log*/,
|
||||
@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// add device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
|
||||
get_device_gemm_add_add_fastgelu_instances<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
D1Layout,
|
||||
ELayout>();
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddAddFastGelu>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
|
||||
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmAlphaBetaPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmAlphaBetaPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
|
||||
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr<
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmBiasReluAddPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
|
||||
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluAddPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluAddPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr<
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmBiasReluPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
|
||||
c0_n_device_buf.ToDevice(c0_n.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// add device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
|
||||
get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>();
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// add device op instances
|
||||
const auto op_ptrs =
|
||||
ck::tensor_operation::device::device_gemm_instance::get_device_gemm_splitk_instances<
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>();
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device operation instance found");
|
||||
}
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_grouped_gemm_instance {
|
||||
namespace instance {
|
||||
|
||||
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<DeviceGroupedGemmNoOpPtr>&);
|
||||
|
||||
} // namespace device_grouped_gemm_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<
|
||||
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
|
||||
gemm_ptrs;
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
|
||||
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_grouped_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_grouped_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_grouped_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_grouped_gemm_instance::
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_normalization_instance {
|
||||
namespace instance {
|
||||
|
||||
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
|
||||
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
|
||||
} // namespace device_normalization_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
tensor_operation::device::device_normalization_instance::
|
||||
add_device_softmax_f16_f16_rank3_instances(instances);
|
||||
tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
|
||||
instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
tensor_operation::device::device_normalization_instance::
|
||||
add_device_softmax_f16_f16_rank4_instances(instances);
|
||||
tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
|
||||
instances);
|
||||
}
|
||||
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
tensor_operation::device::device_normalization_instance::
|
||||
add_device_softmax_f32_f32_rank3_instances(instances);
|
||||
tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
|
||||
instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
tensor_operation::device::device_normalization_instance::
|
||||
add_device_softmax_f32_f32_rank4_instances(instances);
|
||||
tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
|
||||
instances);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
namespace instance {
|
||||
|
||||
template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
|
||||
struct ReduceDescription
|
||||
@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description,
|
||||
return (result);
|
||||
};
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
float beta)
|
||||
{
|
||||
using namespace ck::tensor_operation::device;
|
||||
using namespace ck::tensor_operation::device::device_reduce_instance;
|
||||
using namespace ck::tensor_operation::device::instance;
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
constexpr bool op_support_indices =
|
||||
@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification,
|
||||
bool pass = true;
|
||||
|
||||
using tuple_of_description_instances =
|
||||
tensor_operation::device::device_reduce_instance::reduce_description_instances;
|
||||
tensor_operation::device::instance::reduce_description_instances;
|
||||
|
||||
const auto tuple_object = tuple_of_description_instances{};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user