Fix Tests build (#109)

* fix tests

* remove useless file

* fix test build

* reduce parallelism when compiling

* fix test
This commit is contained in:
Chao Liu
2022-03-05 00:44:11 -06:00
committed by GitHub
parent 7a9b93f4b6
commit 5b178874a1
12 changed files with 21 additions and 33 deletions

View File

@@ -11,8 +11,9 @@
using F16 = ck::half_t;
using F32 = float;
using BF16 = ushort;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
namespace device {
@@ -22,6 +23,7 @@ using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
@@ -30,6 +32,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
@@ -78,7 +81,12 @@ int main(int argc, char* argv[])
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 3)
if(argc == 1)
{
data_type = 1;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -106,11 +114,9 @@ int main(int argc, char* argv[])
}
else
{
printf("arg1: data type (0=fp32 )\n");
printf("arg2: verification (0=no, 1=yes)\n");
printf("arg3: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg4: run kernel # of times (>1)\n");
printf("arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
@@ -296,7 +302,7 @@ int main(int argc, char* argv[])
if(data_type == 0)
{
Run(float(), float(), F32());
Run(F32(), F32(), F32());
}
else if(data_type == 1)
{