From e227fdc1ef5777441c0ef2c8485a10eeb3cff32f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Feb 2026 00:21:04 +0000 Subject: [PATCH] Convert mp_unit tests from gtest to framework.hpp - Modified test/mp_unit/mp_unit_tests.hpp to use ../framework.hpp instead of gtest/gtest.h - Enhanced test/framework.hpp with GTest-compatible APIs: - Added Environment base class for global test setup/teardown - Added TestInfo and UnitTest classes for test metadata access - Added GTEST_SKIP macro support via SkipHelper class - Added namespace alias 'testing' for compatibility - Added InitGoogleTest and AddGlobalTestEnvironment helper functions - Updated test/framework.cc with implementations for new classes - All mp_unit test files now use framework.hpp through mp_unit_tests.hpp - Formatting applied via lint.sh Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com> --- test/executor_test.cc | 7 +- test/framework.cc | 45 +++- test/framework.hpp | 409 ++++++++++++++++++------------ test/mp_unit/mp_unit_tests.hpp | 3 +- test/perf/framework.cc | 4 +- test/perf/framework.hpp | 4 +- test/unit/core_tests.cc | 3 +- test/unit/errors_tests.cc | 4 +- test/unit/fifo_tests.cu | 3 +- test/unit/gpu_utils_tests.cc | 4 +- test/unit/local_channel_tests.cu | 4 +- test/unit/numa_tests.cc | 4 +- test/unit/socket_tests.cc | 3 +- test/unit/utils_internal_tests.cc | 3 +- test/unit/utils_tests.cc | 4 +- 15 files changed, 310 insertions(+), 194 deletions(-) diff --git a/test/executor_test.cc b/test/executor_test.cc index 0e7869ab..cc745659 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -93,11 +93,8 @@ double benchTime(int rank, std::shared_ptr bootstrap, std::s int main(int argc, char* argv[]) { if (argc != 5 && argc != 6) { - std::cerr << "Usage: " << argv[0] << " " - << " " - << " " - << " " - << " (optional) " << std::endl; + std::cerr << "Usage: " << argv[0] << " " << " " << " " + << " " << " (optional) " << std::endl; return 1; } diff --git a/test/framework.cc b/test/framework.cc index 5fd096f1..fc339b76 100644 --- a/test/framework.cc +++ b/test/framework.cc @@ -161,6 +161,12 @@ int runMultipleTests( } // namespace utils +// UnitTest implementation +UnitTest* UnitTest::GetInstance() { + static UnitTest instance; + return &instance; +} + // TestRegistry implementation TestRegistry& TestRegistry::instance() { static TestRegistry registry; @@ -168,19 +174,38 @@ TestRegistry& TestRegistry::instance() { } void TestRegistry::registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory) { - TestInfo info; + TestInfoInternal info; info.suite_name = test_suite; info.test_name = test_name; info.factory = factory; tests_.push_back(info); } +void TestRegistry::addGlobalTestEnvironment(Environment* env) { environments_.push_back(env); } + +void TestRegistry::initGoogleTest(int* argc, char** argv) { + // Parse command-line arguments if needed + // For now, this is a no-op placeholder for compatibility +} + int TestRegistry::runAllTests(int argc, char* argv[]) { // Initialize MPI if not already initialized if (!g_mpi_initialized) { utils::initializeMPI(argc, argv); } + // Set up global test environments + for (auto* env : environments_) { + try { + env->SetUp(); + } catch (const std::exception& e) { + if (g_mpi_rank == 0) { + std::cerr << "Failed to set up test environment: " << e.what() << std::endl; + } + return 1; + } + } + int passed = 0; int failed = 0; @@ -196,6 +221,10 @@ int TestRegistry::runAllTests(int argc, char* argv[]) { std::cout << "[ RUN ] " << test_info.suite_name << "." << test_info.test_name << std::endl; } + // Set current test info for UnitTest::GetInstance()->current_test_info() + TestInfo current_info(test_info.suite_name, test_info.test_name); + UnitTest::GetInstance()->set_current_test_info(¤t_info); + TestCase* test_case = nullptr; try { test_case = test_info.factory(); @@ -216,6 +245,9 @@ int TestRegistry::runAllTests(int argc, char* argv[]) { delete test_case; + // Clear current test info + UnitTest::GetInstance()->set_current_test_info(nullptr); + // Synchronize test status across all MPI processes int local_passed = g_current_test_passed ? 1 : 0; int global_passed = 1; @@ -246,6 +278,17 @@ int TestRegistry::runAllTests(int argc, char* argv[]) { } } + // Tear down global test environments (in reverse order) + for (auto it = environments_.rbegin(); it != environments_.rend(); ++it) { + try { + (*it)->TearDown(); + } catch (const std::exception& e) { + if (g_mpi_rank == 0) { + std::cerr << "Failed to tear down test environment: " << e.what() << std::endl; + } + } + } + return failed > 0 ? 1 : 0; } diff --git a/test/framework.hpp b/test/framework.hpp index 6d510382..1ef9aaea 100644 --- a/test/framework.hpp +++ b/test/framework.hpp @@ -33,6 +33,12 @@ struct TestResult { std::string failure_message; }; +// Forward declarations +class Environment; +class TestCase; +class TestInfo; +class UnitTest; + // Test case base class class TestCase { public: @@ -42,24 +48,61 @@ class TestCase { virtual void TestBody() = 0; }; +// Environment base class (for global test setup/teardown) +class Environment { + public: + virtual ~Environment() = default; + virtual void SetUp() {} + virtual void TearDown() {} +}; + +// Test info class (for getting current test information) +class TestInfo { + public: + TestInfo(const std::string& suite, const std::string& name) : test_suite_name_(suite), test_name_(name) {} + + const char* test_suite_name() const { return test_suite_name_.c_str(); } + const char* name() const { return test_name_.c_str(); } + + private: + std::string test_suite_name_; + std::string test_name_; +}; + +// UnitTest singleton (for getting test information) +class UnitTest { + public: + static UnitTest* GetInstance(); + + const TestInfo* current_test_info() const { return current_test_info_; } + void set_current_test_info(const TestInfo* info) { current_test_info_ = info; } + + private: + UnitTest() = default; + const TestInfo* current_test_info_ = nullptr; +}; + // Test registry and runner class TestRegistry { public: using TestFactory = std::function; - + static TestRegistry& instance(); - + void registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory); + void addGlobalTestEnvironment(Environment* env); int runAllTests(int argc, char* argv[]); - + void initGoogleTest(int* argc, char** argv); + private: TestRegistry() = default; - struct TestInfo { + struct TestInfoInternal { std::string suite_name; std::string test_name; TestFactory factory; }; - std::vector tests_; + std::vector tests_; + std::vector environments_; }; // Simple utility functions for testing @@ -107,230 +150,266 @@ void reportSuccess(); } // namespace mscclpp // Test registration macros -#define TEST(test_suite, test_name) \ - class test_suite##_##test_name##_Test : public ::mscclpp::test::TestCase { \ - public: \ - test_suite##_##test_name##_Test() {} \ - void TestBody() override; \ - }; \ - static bool test_suite##_##test_name##_registered = []() { \ - ::mscclpp::test::TestRegistry::instance().registerTest( \ - #test_suite, #test_name, \ +#define TEST(test_suite, test_name) \ + class test_suite##_##test_name##_Test : public ::mscclpp::test::TestCase { \ + public: \ + test_suite##_##test_name##_Test() {} \ + void TestBody() override; \ + }; \ + static bool test_suite##_##test_name##_registered = []() { \ + ::mscclpp::test::TestRegistry::instance().registerTest( \ + #test_suite, #test_name, \ []() -> ::mscclpp::test::TestCase* { return new test_suite##_##test_name##_Test(); }); \ - return true; \ - }(); \ + return true; \ + }(); \ void test_suite##_##test_name##_Test::TestBody() -#define TEST_F(test_fixture, test_name) \ - class test_fixture##_##test_name##_Test : public test_fixture { \ - public: \ - test_fixture##_##test_name##_Test() {} \ - void TestBody() override; \ - }; \ - static bool test_fixture##_##test_name##_registered = []() { \ - ::mscclpp::test::TestRegistry::instance().registerTest( \ - #test_fixture, #test_name, \ +#define TEST_F(test_fixture, test_name) \ + class test_fixture##_##test_name##_Test : public test_fixture { \ + public: \ + test_fixture##_##test_name##_Test() {} \ + void TestBody() override; \ + }; \ + static bool test_fixture##_##test_name##_registered = []() { \ + ::mscclpp::test::TestRegistry::instance().registerTest( \ + #test_fixture, #test_name, \ []() -> ::mscclpp::test::TestCase* { return new test_fixture##_##test_name##_Test(); }); \ - return true; \ - }(); \ + return true; \ + }(); \ void test_fixture##_##test_name##_Test::TestBody() // Test runner macro #define RUN_ALL_TESTS() ::mscclpp::test::TestRegistry::instance().runAllTests(argc, argv) // Assertion macros -#define EXPECT_TRUE(condition) \ - do { \ - if (!(condition)) { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \ - "Expected: " #condition " to be true"); \ - } \ +#define EXPECT_TRUE(condition) \ + do { \ + if (!(condition)) { \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be true"); \ + } \ } while (0) -#define EXPECT_FALSE(condition) \ - do { \ - if (condition) { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \ - "Expected: " #condition " to be false"); \ - } \ +#define EXPECT_FALSE(condition) \ + do { \ + if (condition) { \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be false"); \ + } \ } while (0) -#define EXPECT_EQ(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 == v2)) { \ - std::ostringstream oss; \ +#define EXPECT_EQ(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 == v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ + } \ } while (0) -#define EXPECT_NE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 != v2)) { \ - std::ostringstream oss; \ +#define EXPECT_NE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 != v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ + } \ } while (0) -#define EXPECT_LT(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 < v2)) { \ - std::ostringstream oss; \ - oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ +#define EXPECT_LT(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 < v2)) { \ + std::ostringstream oss; \ + oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ + } \ } while (0) -#define EXPECT_LE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 <= v2)) { \ - std::ostringstream oss; \ +#define EXPECT_LE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 <= v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ + } \ } while (0) -#define EXPECT_GT(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 > v2)) { \ - std::ostringstream oss; \ - oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ +#define EXPECT_GT(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 > v2)) { \ + std::ostringstream oss; \ + oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ + } \ } while (0) -#define EXPECT_GE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 >= v2)) { \ - std::ostringstream oss; \ +#define EXPECT_GE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 >= v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - } \ + } \ } while (0) -#define ASSERT_TRUE(condition) \ - do { \ - if (!(condition)) { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \ - "Expected: " #condition " to be true"); \ - throw std::runtime_error("Test assertion failed"); \ - } \ +#define ASSERT_TRUE(condition) \ + do { \ + if (!(condition)) { \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be true"); \ + throw std::runtime_error("Test assertion failed"); \ + } \ } while (0) -#define ASSERT_FALSE(condition) \ - do { \ - if (condition) { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \ - "Expected: " #condition " to be false"); \ - throw std::runtime_error("Test assertion failed"); \ - } \ +#define ASSERT_FALSE(condition) \ + do { \ + if (condition) { \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be false"); \ + throw std::runtime_error("Test assertion failed"); \ + } \ } while (0) -#define ASSERT_EQ(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 == v2)) { \ - std::ostringstream oss; \ +#define ASSERT_EQ(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 == v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ throw std::runtime_error("Test assertion failed"); \ - } \ + } \ } while (0) -#define ASSERT_NE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 != v2)) { \ - std::ostringstream oss; \ +#define ASSERT_NE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 != v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ throw std::runtime_error("Test assertion failed"); \ - } \ + } \ } while (0) -#define ASSERT_LT(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 < v2)) { \ - std::ostringstream oss; \ - oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - throw std::runtime_error("Test assertion failed"); \ - } \ +#define ASSERT_LT(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 < v2)) { \ + std::ostringstream oss; \ + oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ + throw std::runtime_error("Test assertion failed"); \ + } \ } while (0) -#define ASSERT_LE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 <= v2)) { \ - std::ostringstream oss; \ +#define ASSERT_LE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 <= v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ throw std::runtime_error("Test assertion failed"); \ - } \ + } \ } while (0) -#define ASSERT_GT(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 > v2)) { \ - std::ostringstream oss; \ - oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - throw std::runtime_error("Test assertion failed"); \ - } \ +#define ASSERT_GT(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 > v2)) { \ + std::ostringstream oss; \ + oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ + throw std::runtime_error("Test assertion failed"); \ + } \ } while (0) -#define ASSERT_GE(val1, val2) \ - do { \ - auto v1 = (val1); \ - auto v2 = (val2); \ - if (!(v1 >= v2)) { \ - std::ostringstream oss; \ +#define ASSERT_GE(val1, val2) \ + do { \ + auto v1 = (val1); \ + auto v2 = (val2); \ + if (!(v1 >= v2)) { \ + std::ostringstream oss; \ oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \ ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ throw std::runtime_error("Test assertion failed"); \ - } \ + } \ } while (0) -#define ASSERT_NO_THROW(statement) \ - do { \ - try { \ - statement; \ - } catch (const std::exception& e) { \ - std::ostringstream oss; \ - oss << "Expected: " #statement " not to throw\n Actual: threw " << e.what(); \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ - throw std::runtime_error("Test assertion failed"); \ - } catch (...) { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \ - "Expected: " #statement " not to throw\n Actual: threw unknown exception"); \ - throw std::runtime_error("Test assertion failed"); \ - } \ +#define ASSERT_NO_THROW(statement) \ + do { \ + try { \ + statement; \ + } catch (const std::exception& e) { \ + std::ostringstream oss; \ + oss << "Expected: " #statement " not to throw\n Actual: threw " << e.what(); \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \ + throw std::runtime_error("Test assertion failed"); \ + } catch (...) { \ + ::mscclpp::test::utils::reportFailure( \ + __FILE__, __LINE__, "Expected: " #statement " not to throw\n Actual: threw unknown exception"); \ + throw std::runtime_error("Test assertion failed"); \ + } \ } while (0) -#define FAIL() \ - do { \ - ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \ - throw std::runtime_error("Test failed"); \ +#define FAIL() \ + do { \ + ::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \ + throw std::runtime_error("Test failed"); \ } while (0) +// Helper class for GTEST_SKIP functionality +class SkipHelper { + public: + explicit SkipHelper(const char* file, int line) : file_(file), line_(line) {} + template + SkipHelper& operator<<(const T& value) { + message_ << value; + return *this; + } + ~SkipHelper() noexcept(false) { + std::string msg = message_.str(); + if (!msg.empty()) { + ::mscclpp::test::utils::reportFailure(file_, line_, "Test skipped: " + msg); + } else { + ::mscclpp::test::utils::reportFailure(file_, line_, "Test skipped"); + } + throw std::runtime_error("Test skipped"); + } + + private: + const char* file_; + int line_; + std::ostringstream message_; +}; + +#define GTEST_SKIP() ::SkipHelper(__FILE__, __LINE__) + +// Create a namespace alias for compatibility with GTest code +namespace testing = ::mscclpp::test; + +// Helper functions for compatibility with GTest API +inline void InitGoogleTest(int* argc, char** argv) { + ::mscclpp::test::TestRegistry::instance().initGoogleTest(argc, argv); +} + +inline ::mscclpp::test::Environment* AddGlobalTestEnvironment(::mscclpp::test::Environment* env) { + ::mscclpp::test::TestRegistry::instance().addGlobalTestEnvironment(env); + return env; +} + #endif // MSCCLPP_TEST_FRAMEWORK_HPP_ diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 17046a57..8b1fab27 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -4,8 +4,6 @@ #ifndef MSCCLPP_MP_UNIT_TESTS_HPP_ #define MSCCLPP_MP_UNIT_TESTS_HPP_ -#include - #include #include #include @@ -13,6 +11,7 @@ #include #include +#include "../framework.hpp" #include "ib.hpp" #include "utils_internal.hpp" diff --git a/test/perf/framework.cc b/test/perf/framework.cc index 600257d1..0b011cc5 100644 --- a/test/perf/framework.cc +++ b/test/perf/framework.cc @@ -12,7 +12,7 @@ namespace mscclpp { namespace test { // Global state for performance test results -static std::vector test_params; @@ -20,7 +20,7 @@ static std::vector g_perf_results; +} > g_perf_results; static std::string getCurrentTimestamp() { auto now = std::chrono::system_clock::now(); diff --git a/test/perf/framework.hpp b/test/perf/framework.hpp index fe49be91..094d5cb1 100644 --- a/test/perf/framework.hpp +++ b/test/perf/framework.hpp @@ -7,10 +7,10 @@ // This file is kept for backwards compatibility with perf tests // The actual framework is now in test/framework.hpp -#include "../framework.hpp" - #include +#include "../framework.hpp" + namespace mscclpp { namespace test { diff --git a/test/unit/core_tests.cc b/test/unit/core_tests.cc index 1c8ee886..a2c39c1b 100644 --- a/test/unit/core_tests.cc +++ b/test/unit/core_tests.cc @@ -2,10 +2,11 @@ // Licensed under the MIT license. #include -#include "../framework.hpp" #include +#include "../framework.hpp" + class LocalCommunicatorTest : public ::testing::Test { protected: void SetUp() override { diff --git a/test/unit/errors_tests.cc b/test/unit/errors_tests.cc index 8d6283d9..4cd68ee6 100644 --- a/test/unit/errors_tests.cc +++ b/test/unit/errors_tests.cc @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include +#include "../framework.hpp" + TEST(ErrorsTest, SystemError) { mscclpp::Error error("test", mscclpp::ErrorCode::SystemError); EXPECT_EQ(error.getErrorCode(), mscclpp::ErrorCode::SystemError); diff --git a/test/unit/fifo_tests.cu b/test/unit/fifo_tests.cu index a0cf5447..68e777d0 100644 --- a/test/unit/fifo_tests.cu +++ b/test/unit/fifo_tests.cu @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include #include #include #include +#include "../framework.hpp" #include "utils_internal.hpp" #define ITER 10000 // should be larger than the FIFO size for proper testing diff --git a/test/unit/gpu_utils_tests.cc b/test/unit/gpu_utils_tests.cc index dc4027a1..c10f113c 100644 --- a/test/unit/gpu_utils_tests.cc +++ b/test/unit/gpu_utils_tests.cc @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include +#include "../framework.hpp" + TEST(GpuUtilsTest, StreamPool) { auto streamPool = mscclpp::gpuStreamPool(); cudaStream_t s; diff --git a/test/unit/local_channel_tests.cu b/test/unit/local_channel_tests.cu index d7cd4c65..76060f97 100644 --- a/test/unit/local_channel_tests.cu +++ b/test/unit/local_channel_tests.cu @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include #include #include #include +#include "../framework.hpp" + #define MAGIC_CONST 777 __constant__ mscclpp::PortChannelDeviceHandle gPortChannel; diff --git a/test/unit/numa_tests.cc b/test/unit/numa_tests.cc index 31ba373c..c27fde90 100644 --- a/test/unit/numa_tests.cc +++ b/test/unit/numa_tests.cc @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include #include +#include "../framework.hpp" + TEST(NumaTest, Basic) { int num; MSCCLPP_CUDATHROW(cudaGetDeviceCount(&num)); diff --git a/test/unit/socket_tests.cc b/test/unit/socket_tests.cc index cfd5bd4f..6b7c1903 100644 --- a/test/unit/socket_tests.cc +++ b/test/unit/socket_tests.cc @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include #include +#include "../framework.hpp" #include "socket.h" #include "utils_internal.hpp" diff --git a/test/unit/utils_internal_tests.cc b/test/unit/utils_internal_tests.cc index 73b03833..8526d9fe 100644 --- a/test/unit/utils_internal_tests.cc +++ b/test/unit/utils_internal_tests.cc @@ -1,10 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#include "../framework.hpp" - #include +#include "../framework.hpp" #include "utils_internal.hpp" TEST(UtilsInternalTest, getHostHash) { diff --git a/test/unit/utils_tests.cc b/test/unit/utils_tests.cc index ae77892d..110550da 100644 --- a/test/unit/utils_tests.cc +++ b/test/unit/utils_tests.cc @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "../framework.hpp" - #include #include #include +#include "../framework.hpp" + TEST(UtilsTest, getHostName) { std::string hostname1 = mscclpp::getHostName(1024, '.'); EXPECT_FALSE(hostname1.empty());