include(gtest)

# Helper function to create a gtest executable with common properties.
#
# Usage:
#   add_ck_builder_test(<test_name> <source>...)
#
# Arguments:
#   test_name - name of the executable target to create.
#   ARGN      - source files to compile into the executable.
#
# testing_utils.cpp is always compiled into the target, so callers do not need
# to list it. If a caller lists it anyway, the repeated entry is dropped.
function(add_ck_builder_test test_name)
    # Collect the caller's sources plus the shared utilities, then remove
    # repeats so each file is handed to add_executable exactly once.
    set(test_sources ${ARGN} testing_utils.cpp)
    list(REMOVE_DUPLICATES test_sources)

    add_executable(${test_name} ${test_sources})
    target_compile_features(${test_name} PRIVATE cxx_std_20)
    target_include_directories(${test_name} PRIVATE
        "${PROJECT_SOURCE_DIR}/experimental/builder/include"
        "${PROJECT_SOURCE_DIR}/include"
    )
    # Silence these two warnings for test code.
    # NOTE(review): presumably triggered by gtest's static test registration
    # and by C++20 constructs in the tests - confirm.
    target_compile_options(${test_name} PRIVATE
        -Wno-global-constructors
        -Wno-c++20-compat
    )
    target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
endfunction()

# add_ck_builder_test compiles testing_utils.cpp into every target it creates,
# so that file is not listed again in the calls below.
add_ck_builder_test(test_conv_builder
    test_conv_builder.cpp
    test_instance_traits.cpp
    test_instance_traits_util.cpp)

add_ck_builder_test(test_get_instance_string
    test_get_instance_string.cpp)

add_ck_builder_test(test_inline_diff test_inline_diff.cpp)

# Create a builder test that also links the device convolution operations.
#
# Usage:
#   add_ck_factory_test(<test_name> <source>...)
#
# All arguments are forwarded unchanged to add_ck_builder_test; the resulting
# target is then linked privately against
# composablekernels::device_conv_operations.
function(add_ck_factory_test test_name)
    add_ck_builder_test(${test_name} ${ARGN})
    target_link_libraries(${test_name}
        PRIVATE
            composablekernels::device_conv_operations
    )
endfunction()

# Factory tests: each target is built from the single source file that shares
# its name (<target>.cpp), so the targets are declared in one loop.
foreach(factory_test_name IN ITEMS
        test_testing_utils
        test_ck_factory_grouped_convolution_forward
        test_ck_factory_grouped_convolution_forward_clamp
        test_ck_factory_grouped_convolution_forward_convscale
        test_ck_factory_grouped_convolution_forward_bilinear
        test_ck_factory_grouped_convolution_forward_scale
        test_ck_factory_grouped_convolution_forward_scaleadd_ab
        test_ck_factory_grouped_convolution_forward_bias_clamp
        test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp
        test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu)
    add_ck_factory_test(${factory_test_name} ${factory_test_name}.cpp)
endforeach()
