diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7c4e9684..20c469d7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -find_package(MPI) +find_package(MPI REQUIRED) set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} Threads::Threads) if(MSCCLPP_USE_IB) @@ -40,7 +40,7 @@ include(CTest) # Build test framework library add_library(test_framework STATIC framework.cc) -target_include_directories(test_framework PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(test_framework PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${TEST_INC_COMMON}) target_link_libraries(test_framework PUBLIC MPI::MPI_CXX) # Unit tests diff --git a/test/framework.hpp b/test/framework.hpp index 4b953e37..cfd9ecf6 100644 --- a/test/framework.hpp +++ b/test/framework.hpp @@ -366,11 +366,34 @@ void reportSuccess(); } \ } while (0) -#define FAIL() \ - do { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \ - throw std::runtime_error("Test failed"); \ - } while (0) +// Helper class for FAIL functionality with message streaming support +class FailHelper { + public: + explicit FailHelper(const char* file, int line) : file_(file), line_(line) {} + template + FailHelper& operator<<(const T& value) { + message_ << value; + return *this; + } + ~FailHelper() noexcept(false) { + std::string msg = message_.str(); + if (!msg.empty()) { + ::mscclpp::test::utils::reportFailure(file_, line_, "Test failed: " + msg); + } else { + ::mscclpp::test::utils::reportFailure(file_, line_, "Test failed"); + } + throw std::runtime_error("Test failed"); + } + + private: + const char* file_; + int line_; + std::ostringstream message_; +}; + +// Test fail macro - throws exception to fail test execution +// Usage: FAIL() << "Optional fail message"; +#define FAIL() ::mscclpp::test::FailHelper(__FILE__, __LINE__) // Helper class for GTEST_SKIP functionality // This class uses RAII (Resource Acquisition Is Initialization) pattern: diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 8b1fab27..bcf880ae 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -15,7 +15,7 @@ #include "ib.hpp" #include "utils_internal.hpp" -class MultiProcessTestEnv : public ::testing::Environment { +class MultiProcessTestEnv : public ::mscclpp::test::Environment { public: MultiProcessTestEnv(int argc, const char** argv); @@ -36,7 +36,7 @@ mscclpp::Transport ibIdToTransport(int id); int rankToLocalRank(int rank); int rankToNode(int rank); -class MultiProcessTest : public ::testing::Test { +class MultiProcessTest : public ::mscclpp::test::TestCase { protected: void TearDown() override; }; diff --git a/test/unit/core_tests.cc b/test/unit/core_tests.cc index a2c39c1b..13437872 100644 --- a/test/unit/core_tests.cc +++ b/test/unit/core_tests.cc @@ -1,13 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include +#include "../framework.hpp" #include -#include "../framework.hpp" - -class LocalCommunicatorTest : public ::testing::Test { +class LocalCommunicatorTest : public ::mscclpp::test::TestCase { protected: void SetUp() override { bootstrap = std::make_shared(0, 1);