mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[HotFix] add config and version files to pass on build info (#856)
* experiment with config file * experiment with version.h config * add more info to version.h * minor updates * minor updates * fix case where DTYPE is not used * large amount of files but minor changes * remove white space * minor changes to add more MACROs * fix cmakedefine01 * fix issue with CK internal conflict * fix define and define value * fix clang-format * fix formatting issue * experiment with cmake * clang format v12 to be consistent with miopen * avoid clang-format for config file
This commit is contained in:
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -37,7 +37,7 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
|
||||
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -58,7 +58,7 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
|
||||
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -79,7 +79,7 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
|
||||
DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
|
||||
Row,
|
||||
@@ -154,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
@@ -180,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -206,7 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, int8_t>)
|
||||
{
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
@@ -59,7 +59,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
|
||||
MaskingSpecialization::MaskDisabled>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
|
||||
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
|
||||
Acc0BiasDataType::Size() == 1 &&
|
||||
@@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
|
||||
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
|
||||
Acc0BiasDataType::Size() == 1 &&
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
|
||||
Row,
|
||||
@@ -124,7 +124,7 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
|
||||
Row,
|
||||
@@ -263,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -297,7 +297,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<EDataType, int8_t>)
|
||||
{
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
@@ -59,7 +59,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
|
||||
MaskingSpecialization::MaskDisabled>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
|
||||
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
|
||||
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
// float
|
||||
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
@@ -66,7 +66,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp64__
|
||||
#ifdef CK_ENABLE_FP64
|
||||
// double
|
||||
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<DDataType, float> && is_same_v<EDataType, float>)
|
||||
{
|
||||
@@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp64__
|
||||
#ifdef CK_ENABLE_FP64
|
||||
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
|
||||
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
// float
|
||||
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
@@ -66,7 +66,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp64__
|
||||
#ifdef CK_ENABLE_FP64
|
||||
// double
|
||||
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
@@ -149,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<EDataType, float>)
|
||||
{
|
||||
@@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp64__
|
||||
#ifdef CK_ENABLE_FP64
|
||||
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
|
||||
is_same_v<EDataType, double>)
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// conv1d backward data
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<1,
|
||||
@@ -30,19 +30,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<1,
|
||||
NWC,
|
||||
@@ -55,7 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// conv2d backward data
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
@@ -69,7 +69,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -82,7 +82,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -95,7 +95,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -109,7 +109,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef DL_KERNELS
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// conv2d dl
|
||||
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
@@ -123,7 +123,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -136,7 +136,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -150,7 +150,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// conv3d backward data
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
@@ -164,7 +164,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
@@ -177,7 +177,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
@@ -190,7 +190,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
@@ -245,21 +245,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
@@ -278,7 +278,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -288,14 +288,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
@@ -314,21 +314,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d forward
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -28,7 +28,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvFwd<2,
|
||||
NHWC,
|
||||
@@ -41,13 +41,13 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvFwd<2,
|
||||
NHWC,
|
||||
@@ -103,7 +103,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -111,7 +111,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -119,7 +119,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(__fp16__) && defined(DL_KERNELS)
|
||||
#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -78,7 +78,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(__fp32__) && defined(DL_KERNELS)
|
||||
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -99,7 +99,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(__int8__) && defined(DL_KERNELS)
|
||||
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -140,7 +140,7 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -161,7 +161,7 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -208,7 +208,7 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
instances);
|
||||
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -229,7 +229,7 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -270,7 +270,7 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef __fp64__
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
|
||||
@@ -363,7 +363,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -414,7 +414,7 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __bf16__
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
|
||||
is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
@@ -440,7 +440,7 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, int8_t>)
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// FP16
|
||||
void add_device_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
|
||||
@@ -27,7 +27,7 @@ void add_device_normalization_rank_4_3_f16_instances(
|
||||
void add_device_normalization_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
// FP32
|
||||
void add_device_normalization_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
|
||||
@@ -66,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
|
||||
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
|
||||
{
|
||||
@@ -84,7 +84,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
|
||||
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
|
||||
|
||||
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
|
||||
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// FP16
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
@@ -37,7 +37,7 @@ void add_device_pool3d_fwd_ndhwc_index_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
// FP32
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
@@ -84,7 +84,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
|
||||
{
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
@@ -98,7 +98,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#ifdef __int8__
|
||||
#ifdef CK_ENABLE_INT8
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise(
|
||||
{
|
||||
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
|
||||
[&](auto i) {
|
||||
using cfg1 = remove_cvref_t<decltype(
|
||||
std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
|
||||
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
|
||||
reduce_configuration_1_instances_blockwise{}))>;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_blockwise{}))>;
|
||||
|
||||
using ReduceOpInstance =
|
||||
DeviceReduceMultiBlock<InDataType,
|
||||
|
||||
@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
static_for<0,
|
||||
std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
|
||||
1>{}([&](auto i) {
|
||||
using cfg1 = remove_cvref_t<decltype(
|
||||
std::get<i.value>(reduce_configuration_1_instances_multiblock_atomic_add{}))>;
|
||||
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
|
||||
reduce_configuration_1_instances_multiblock_atomic_add{}))>;
|
||||
|
||||
static_for<0,
|
||||
std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
|
||||
1>{}([&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>;
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_multiblock_atomic_add{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise(
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_threadwise{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -40,7 +40,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
|
||||
std::is_same_v<OutDataType, F16>)
|
||||
{
|
||||
@@ -66,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
|
||||
std::is_same_v<OutDataType, F32>)
|
||||
{
|
||||
|
||||
@@ -102,9 +102,10 @@ struct FillMonotonicSeq
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<decltype(
|
||||
std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
|
||||
Reference in New Issue
Block a user