mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Convert mp_unit tests from gtest to framework.hpp
- Modified test/mp_unit/mp_unit_tests.hpp to use ../framework.hpp instead of gtest/gtest.h - Enhanced test/framework.hpp with GTest-compatible APIs: - Added Environment base class for global test setup/teardown - Added TestInfo and UnitTest classes for test metadata access - Added GTEST_SKIP macro support via SkipHelper class - Added namespace alias 'testing' for compatibility - Added InitGoogleTest and AddGlobalTestEnvironment helper functions - Updated test/framework.cc with implementations for new classes - All mp_unit test files now use framework.hpp through mp_unit_tests.hpp - Formatting applied via lint.sh Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
This commit is contained in:
@@ -93,11 +93,8 @@ double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::s
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc != 5 && argc != 6) {
|
||||
std::cerr << "Usage: " << argv[0] << " <buffer size>"
|
||||
<< " <execution plan path>"
|
||||
<< " <number of iterations>"
|
||||
<< " <number of graph iterations>"
|
||||
<< " (optional) <packet type>" << std::endl;
|
||||
std::cerr << "Usage: " << argv[0] << " <buffer size>" << " <execution plan path>" << " <number of iterations>"
|
||||
<< " <number of graph iterations>" << " (optional) <packet type>" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -161,6 +161,12 @@ int runMultipleTests(
|
||||
|
||||
} // namespace utils
|
||||
|
||||
// UnitTest implementation
|
||||
UnitTest* UnitTest::GetInstance() {
|
||||
static UnitTest instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
// TestRegistry implementation
|
||||
TestRegistry& TestRegistry::instance() {
|
||||
static TestRegistry registry;
|
||||
@@ -168,19 +174,38 @@ TestRegistry& TestRegistry::instance() {
|
||||
}
|
||||
|
||||
void TestRegistry::registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory) {
|
||||
TestInfo info;
|
||||
TestInfoInternal info;
|
||||
info.suite_name = test_suite;
|
||||
info.test_name = test_name;
|
||||
info.factory = factory;
|
||||
tests_.push_back(info);
|
||||
}
|
||||
|
||||
void TestRegistry::addGlobalTestEnvironment(Environment* env) { environments_.push_back(env); }
|
||||
|
||||
void TestRegistry::initGoogleTest(int* argc, char** argv) {
|
||||
// Parse command-line arguments if needed
|
||||
// For now, this is a no-op placeholder for compatibility
|
||||
}
|
||||
|
||||
int TestRegistry::runAllTests(int argc, char* argv[]) {
|
||||
// Initialize MPI if not already initialized
|
||||
if (!g_mpi_initialized) {
|
||||
utils::initializeMPI(argc, argv);
|
||||
}
|
||||
|
||||
// Set up global test environments
|
||||
for (auto* env : environments_) {
|
||||
try {
|
||||
env->SetUp();
|
||||
} catch (const std::exception& e) {
|
||||
if (g_mpi_rank == 0) {
|
||||
std::cerr << "Failed to set up test environment: " << e.what() << std::endl;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
int passed = 0;
|
||||
int failed = 0;
|
||||
|
||||
@@ -196,6 +221,10 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
|
||||
std::cout << "[ RUN ] " << test_info.suite_name << "." << test_info.test_name << std::endl;
|
||||
}
|
||||
|
||||
// Set current test info for UnitTest::GetInstance()->current_test_info()
|
||||
TestInfo current_info(test_info.suite_name, test_info.test_name);
|
||||
UnitTest::GetInstance()->set_current_test_info(¤t_info);
|
||||
|
||||
TestCase* test_case = nullptr;
|
||||
try {
|
||||
test_case = test_info.factory();
|
||||
@@ -216,6 +245,9 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
|
||||
|
||||
delete test_case;
|
||||
|
||||
// Clear current test info
|
||||
UnitTest::GetInstance()->set_current_test_info(nullptr);
|
||||
|
||||
// Synchronize test status across all MPI processes
|
||||
int local_passed = g_current_test_passed ? 1 : 0;
|
||||
int global_passed = 1;
|
||||
@@ -246,6 +278,17 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tear down global test environments (in reverse order)
|
||||
for (auto it = environments_.rbegin(); it != environments_.rend(); ++it) {
|
||||
try {
|
||||
(*it)->TearDown();
|
||||
} catch (const std::exception& e) {
|
||||
if (g_mpi_rank == 0) {
|
||||
std::cerr << "Failed to tear down test environment: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return failed > 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,12 @@ struct TestResult {
|
||||
std::string failure_message;
|
||||
};
|
||||
|
||||
// Forward declarations
|
||||
class Environment;
|
||||
class TestCase;
|
||||
class TestInfo;
|
||||
class UnitTest;
|
||||
|
||||
// Test case base class
|
||||
class TestCase {
|
||||
public:
|
||||
@@ -42,24 +48,61 @@ class TestCase {
|
||||
virtual void TestBody() = 0;
|
||||
};
|
||||
|
||||
// Environment base class (for global test setup/teardown)
|
||||
class Environment {
|
||||
public:
|
||||
virtual ~Environment() = default;
|
||||
virtual void SetUp() {}
|
||||
virtual void TearDown() {}
|
||||
};
|
||||
|
||||
// Test info class (for getting current test information)
|
||||
class TestInfo {
|
||||
public:
|
||||
TestInfo(const std::string& suite, const std::string& name) : test_suite_name_(suite), test_name_(name) {}
|
||||
|
||||
const char* test_suite_name() const { return test_suite_name_.c_str(); }
|
||||
const char* name() const { return test_name_.c_str(); }
|
||||
|
||||
private:
|
||||
std::string test_suite_name_;
|
||||
std::string test_name_;
|
||||
};
|
||||
|
||||
// UnitTest singleton (for getting test information)
|
||||
class UnitTest {
|
||||
public:
|
||||
static UnitTest* GetInstance();
|
||||
|
||||
const TestInfo* current_test_info() const { return current_test_info_; }
|
||||
void set_current_test_info(const TestInfo* info) { current_test_info_ = info; }
|
||||
|
||||
private:
|
||||
UnitTest() = default;
|
||||
const TestInfo* current_test_info_ = nullptr;
|
||||
};
|
||||
|
||||
// Test registry and runner
|
||||
class TestRegistry {
|
||||
public:
|
||||
using TestFactory = std::function<TestCase*()>;
|
||||
|
||||
|
||||
static TestRegistry& instance();
|
||||
|
||||
|
||||
void registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory);
|
||||
void addGlobalTestEnvironment(Environment* env);
|
||||
int runAllTests(int argc, char* argv[]);
|
||||
|
||||
void initGoogleTest(int* argc, char** argv);
|
||||
|
||||
private:
|
||||
TestRegistry() = default;
|
||||
struct TestInfo {
|
||||
struct TestInfoInternal {
|
||||
std::string suite_name;
|
||||
std::string test_name;
|
||||
TestFactory factory;
|
||||
};
|
||||
std::vector<TestInfo> tests_;
|
||||
std::vector<TestInfoInternal> tests_;
|
||||
std::vector<Environment*> environments_;
|
||||
};
|
||||
|
||||
// Simple utility functions for testing
|
||||
@@ -107,230 +150,266 @@ void reportSuccess();
|
||||
} // namespace mscclpp
|
||||
|
||||
// Test registration macros
|
||||
#define TEST(test_suite, test_name) \
|
||||
class test_suite##_##test_name##_Test : public ::mscclpp::test::TestCase { \
|
||||
public: \
|
||||
test_suite##_##test_name##_Test() {} \
|
||||
void TestBody() override; \
|
||||
}; \
|
||||
static bool test_suite##_##test_name##_registered = []() { \
|
||||
::mscclpp::test::TestRegistry::instance().registerTest( \
|
||||
#test_suite, #test_name, \
|
||||
#define TEST(test_suite, test_name) \
|
||||
class test_suite##_##test_name##_Test : public ::mscclpp::test::TestCase { \
|
||||
public: \
|
||||
test_suite##_##test_name##_Test() {} \
|
||||
void TestBody() override; \
|
||||
}; \
|
||||
static bool test_suite##_##test_name##_registered = []() { \
|
||||
::mscclpp::test::TestRegistry::instance().registerTest( \
|
||||
#test_suite, #test_name, \
|
||||
[]() -> ::mscclpp::test::TestCase* { return new test_suite##_##test_name##_Test(); }); \
|
||||
return true; \
|
||||
}(); \
|
||||
return true; \
|
||||
}(); \
|
||||
void test_suite##_##test_name##_Test::TestBody()
|
||||
|
||||
#define TEST_F(test_fixture, test_name) \
|
||||
class test_fixture##_##test_name##_Test : public test_fixture { \
|
||||
public: \
|
||||
test_fixture##_##test_name##_Test() {} \
|
||||
void TestBody() override; \
|
||||
}; \
|
||||
static bool test_fixture##_##test_name##_registered = []() { \
|
||||
::mscclpp::test::TestRegistry::instance().registerTest( \
|
||||
#test_fixture, #test_name, \
|
||||
#define TEST_F(test_fixture, test_name) \
|
||||
class test_fixture##_##test_name##_Test : public test_fixture { \
|
||||
public: \
|
||||
test_fixture##_##test_name##_Test() {} \
|
||||
void TestBody() override; \
|
||||
}; \
|
||||
static bool test_fixture##_##test_name##_registered = []() { \
|
||||
::mscclpp::test::TestRegistry::instance().registerTest( \
|
||||
#test_fixture, #test_name, \
|
||||
[]() -> ::mscclpp::test::TestCase* { return new test_fixture##_##test_name##_Test(); }); \
|
||||
return true; \
|
||||
}(); \
|
||||
return true; \
|
||||
}(); \
|
||||
void test_fixture##_##test_name##_Test::TestBody()
|
||||
|
||||
// Test runner macro
|
||||
#define RUN_ALL_TESTS() ::mscclpp::test::TestRegistry::instance().runAllTests(argc, argv)
|
||||
|
||||
// Assertion macros
|
||||
#define EXPECT_TRUE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
|
||||
"Expected: " #condition " to be true"); \
|
||||
} \
|
||||
#define EXPECT_TRUE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be true"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_FALSE(condition) \
|
||||
do { \
|
||||
if (condition) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
|
||||
"Expected: " #condition " to be false"); \
|
||||
} \
|
||||
#define EXPECT_FALSE(condition) \
|
||||
do { \
|
||||
if (condition) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be false"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_EQ(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 == v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define EXPECT_EQ(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 == v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_NE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 != v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define EXPECT_NE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 != v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_LT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 < v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
#define EXPECT_LT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 < v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_LE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 <= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define EXPECT_LE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 <= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_GT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 > v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
#define EXPECT_GT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 > v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define EXPECT_GE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 >= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define EXPECT_GE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 >= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_TRUE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
|
||||
"Expected: " #condition " to be true"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
#define ASSERT_TRUE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be true"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_FALSE(condition) \
|
||||
do { \
|
||||
if (condition) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
|
||||
"Expected: " #condition " to be false"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
#define ASSERT_FALSE(condition) \
|
||||
do { \
|
||||
if (condition) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Expected: " #condition " to be false"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_EQ(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 == v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define ASSERT_EQ(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 == v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_NE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 != v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define ASSERT_NE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 != v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_LT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 < v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
#define ASSERT_LT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 < v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_LE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 <= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define ASSERT_LE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 <= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_GT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 > v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
#define ASSERT_GT(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 > v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_GE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 >= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
#define ASSERT_GE(val1, val2) \
|
||||
do { \
|
||||
auto v1 = (val1); \
|
||||
auto v2 = (val2); \
|
||||
if (!(v1 >= v2)) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_NO_THROW(statement) \
|
||||
do { \
|
||||
try { \
|
||||
statement; \
|
||||
} catch (const std::exception& e) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #statement " not to throw\n Actual: threw " << e.what(); \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} catch (...) { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
|
||||
"Expected: " #statement " not to throw\n Actual: threw unknown exception"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
#define ASSERT_NO_THROW(statement) \
|
||||
do { \
|
||||
try { \
|
||||
statement; \
|
||||
} catch (const std::exception& e) { \
|
||||
std::ostringstream oss; \
|
||||
oss << "Expected: " #statement " not to throw\n Actual: threw " << e.what(); \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} catch (...) { \
|
||||
::mscclpp::test::utils::reportFailure( \
|
||||
__FILE__, __LINE__, "Expected: " #statement " not to throw\n Actual: threw unknown exception"); \
|
||||
throw std::runtime_error("Test assertion failed"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define FAIL() \
|
||||
do { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \
|
||||
throw std::runtime_error("Test failed"); \
|
||||
#define FAIL() \
|
||||
do { \
|
||||
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \
|
||||
throw std::runtime_error("Test failed"); \
|
||||
} while (0)
|
||||
|
||||
// Helper class for GTEST_SKIP functionality
|
||||
class SkipHelper {
|
||||
public:
|
||||
explicit SkipHelper(const char* file, int line) : file_(file), line_(line) {}
|
||||
template <typename T>
|
||||
SkipHelper& operator<<(const T& value) {
|
||||
message_ << value;
|
||||
return *this;
|
||||
}
|
||||
~SkipHelper() noexcept(false) {
|
||||
std::string msg = message_.str();
|
||||
if (!msg.empty()) {
|
||||
::mscclpp::test::utils::reportFailure(file_, line_, "Test skipped: " + msg);
|
||||
} else {
|
||||
::mscclpp::test::utils::reportFailure(file_, line_, "Test skipped");
|
||||
}
|
||||
throw std::runtime_error("Test skipped");
|
||||
}
|
||||
|
||||
private:
|
||||
const char* file_;
|
||||
int line_;
|
||||
std::ostringstream message_;
|
||||
};
|
||||
|
||||
#define GTEST_SKIP() ::SkipHelper(__FILE__, __LINE__)
|
||||
|
||||
// Create a namespace alias for compatibility with GTest code
|
||||
namespace testing = ::mscclpp::test;
|
||||
|
||||
// Helper functions for compatibility with GTest API
|
||||
inline void InitGoogleTest(int* argc, char** argv) {
|
||||
::mscclpp::test::TestRegistry::instance().initGoogleTest(argc, argv);
|
||||
}
|
||||
|
||||
inline ::mscclpp::test::Environment* AddGlobalTestEnvironment(::mscclpp::test::Environment* env) {
|
||||
::mscclpp::test::TestRegistry::instance().addGlobalTestEnvironment(env);
|
||||
return env;
|
||||
}
|
||||
|
||||
#endif // MSCCLPP_TEST_FRAMEWORK_HPP_
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
#ifndef MSCCLPP_MP_UNIT_TESTS_HPP_
|
||||
#define MSCCLPP_MP_UNIT_TESTS_HPP_
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
@@ -13,6 +11,7 @@
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace mscclpp {
|
||||
namespace test {
|
||||
|
||||
// Global state for performance test results
|
||||
static std::vector<struct PerfTestResult {
|
||||
static std::vector < struct PerfTestResult {
|
||||
std::string test_name;
|
||||
std::string test_category;
|
||||
std::map<std::string, std::string> test_params;
|
||||
@@ -20,7 +20,7 @@ static std::vector<struct PerfTestResult {
|
||||
int num_processes;
|
||||
int process_rank;
|
||||
std::string timestamp;
|
||||
}> g_perf_results;
|
||||
} > g_perf_results;
|
||||
|
||||
static std::string getCurrentTimestamp() {
|
||||
auto now = std::chrono::system_clock::now();
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
// This file is kept for backwards compatibility with perf tests
|
||||
// The actual framework is now in test/framework.hpp
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace test {
|
||||
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
class LocalCommunicatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
TEST(ErrorsTest, SystemError) {
|
||||
mscclpp::Error error("test", mscclpp::ErrorCode::SystemError);
|
||||
EXPECT_EQ(error.getErrorCode(), mscclpp::ErrorCode::SystemError);
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
#define ITER 10000 // should be larger than the FIFO size for proper testing
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
TEST(GpuUtilsTest, StreamPool) {
|
||||
auto streamPool = mscclpp::gpuStreamPool();
|
||||
cudaStream_t s;
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#define MAGIC_CONST 777
|
||||
|
||||
__constant__ mscclpp::PortChannelDeviceHandle gPortChannel;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
TEST(NumaTest, Basic) {
|
||||
int num;
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceCount(&num));
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <thread>
|
||||
|
||||
#include "../framework.hpp"
|
||||
#include "socket.h"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "../framework.hpp"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
TEST(UtilsInternalTest, getHostHash) {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <thread>
|
||||
|
||||
#include "../framework.hpp"
|
||||
|
||||
TEST(UtilsTest, getHostName) {
|
||||
std::string hostname1 = mscclpp::getHostName(1024, '.');
|
||||
EXPECT_FALSE(hostname1.empty());
|
||||
|
||||
Reference in New Issue
Block a user