mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Fix Tests build (#109)
* fix tests * remove useless file * fix test build * reduce parallelism when compiling * fix test
This commit is contained in:
@@ -11,8 +11,9 @@
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ushort;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using INT8 = int8_t;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -22,6 +23,7 @@ using DeviceConvBwdDataNoOpPtr =
|
||||
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
@@ -30,6 +32,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
|
||||
} // namespace device_conv2d_bwd_data_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
@@ -78,7 +81,12 @@ int main(int argc, char* argv[])
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 3)
|
||||
if(argc == 1)
|
||||
{
|
||||
data_type = 1;
|
||||
init_method = 1;
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
data_type = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -106,11 +114,9 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: data type (0=fp32 )\n");
|
||||
printf("arg2: verification (0=no, 1=yes)\n");
|
||||
printf("arg3: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg4: run kernel # of times (>1)\n");
|
||||
printf("arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
@@ -296,7 +302,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(data_type == 0)
|
||||
{
|
||||
Run(float(), float(), F32());
|
||||
Run(F32(), F32(), F32());
|
||||
}
|
||||
else if(data_type == 1)
|
||||
{
|
||||
Reference in New Issue
Block a user