Files
mscclpp/test/framework.hpp
2026-02-11 00:17:18 +00:00

337 lines
19 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_TEST_FRAMEWORK_HPP_
#define MSCCLPP_TEST_FRAMEWORK_HPP_
#include <mpi.h>
#include <chrono>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <mscclpp/gpu.hpp>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
namespace mscclpp {
namespace test {
// Test result structure
struct TestResult {
std::string test_name;
std::string test_category;
std::map<std::string, std::string> test_params;
int num_processes;
int process_rank;
std::string timestamp;
bool passed;
std::string failure_message;
};
// Test case base class
class TestCase {
public:
virtual ~TestCase() = default;
virtual void SetUp() {}
virtual void TearDown() {}
virtual void TestBody() = 0;
};
// Test registry and runner
class TestRegistry {
public:
using TestFactory = std::function<TestCase*()>;
static TestRegistry& instance();
void registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory);
int runAllTests(int argc, char* argv[]);
private:
TestRegistry() = default;
struct TestInfo {
std::string suite_name;
std::string test_name;
TestFactory factory;
};
std::vector<TestInfo> tests_;
};
// Simple utility functions for testing
namespace utils {
// Test execution utilities (for performance tests)
int runMultipleTests(
int argc, char* argv[],
const std::vector<std::tuple<std::string, std::string, std::function<void(int, int, int)>>>& tests);
// MPI management
void initializeMPI(int argc, char* argv[]);
void cleanupMPI();
bool isMainRank();
int getMPIRank();
int getMPISize();
// Timing utilities
class Timer {
public:
Timer();
void start();
void stop();
double elapsedMicroseconds() const;
double elapsedMilliseconds() const;
double elapsedSeconds() const;
private:
std::chrono::high_resolution_clock::time_point start_time_;
std::chrono::high_resolution_clock::time_point end_time_;
bool is_running_;
};
// CUDA utilities
void cudaCheck(cudaError_t err, const char* file, int line);
#define CUDA_CHECK(call) mscclpp::test::utils::cudaCheck(call, __FILE__, __LINE__)
// Test assertion helpers
void reportFailure(const char* file, int line, const std::string& message);
void reportSuccess();
} // namespace utils
} // namespace test
} // namespace mscclpp
// Test registration macros
#define TEST(test_suite, test_name) \
class test_suite##_##test_name##_Test : public ::mscclpp::test::TestCase { \
public: \
test_suite##_##test_name##_Test() {} \
void TestBody() override; \
}; \
static bool test_suite##_##test_name##_registered = []() { \
::mscclpp::test::TestRegistry::instance().registerTest( \
#test_suite, #test_name, \
[]() -> ::mscclpp::test::TestCase* { return new test_suite##_##test_name##_Test(); }); \
return true; \
}(); \
void test_suite##_##test_name##_Test::TestBody()
#define TEST_F(test_fixture, test_name) \
class test_fixture##_##test_name##_Test : public test_fixture { \
public: \
test_fixture##_##test_name##_Test() {} \
void TestBody() override; \
}; \
static bool test_fixture##_##test_name##_registered = []() { \
::mscclpp::test::TestRegistry::instance().registerTest( \
#test_fixture, #test_name, \
[]() -> ::mscclpp::test::TestCase* { return new test_fixture##_##test_name##_Test(); }); \
return true; \
}(); \
void test_fixture##_##test_name##_Test::TestBody()
// Test runner macro
#define RUN_ALL_TESTS() ::mscclpp::test::TestRegistry::instance().runAllTests(argc, argv)
// Assertion macros
#define EXPECT_TRUE(condition) \
do { \
if (!(condition)) { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
"Expected: " #condition " to be true"); \
} \
} while (0)
#define EXPECT_FALSE(condition) \
do { \
if (condition) { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
"Expected: " #condition " to be false"); \
} \
} while (0)
#define EXPECT_EQ(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 == v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define EXPECT_NE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 != v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define EXPECT_LT(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 < v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define EXPECT_LE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 <= v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define EXPECT_GT(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 > v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define EXPECT_GE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 >= v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
} \
} while (0)
#define ASSERT_TRUE(condition) \
do { \
if (!(condition)) { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
"Expected: " #condition " to be true"); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_FALSE(condition) \
do { \
if (condition) { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
"Expected: " #condition " to be false"); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_EQ(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 == v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " == " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_NE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 != v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " != " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_LT(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 < v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " < " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_LE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 <= v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " <= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_GT(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 > v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " > " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_GE(val1, val2) \
do { \
auto v1 = (val1); \
auto v2 = (val2); \
if (!(v1 >= v2)) { \
std::ostringstream oss; \
oss << "Expected: " #val1 " >= " #val2 << "\n Actual: " << v1 << " vs " << v2; \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define ASSERT_NO_THROW(statement) \
do { \
try { \
statement; \
} catch (const std::exception& e) { \
std::ostringstream oss; \
oss << "Expected: " #statement " not to throw\n Actual: threw " << e.what(); \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, oss.str()); \
throw std::runtime_error("Test assertion failed"); \
} catch (...) { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, \
"Expected: " #statement " not to throw\n Actual: threw unknown exception"); \
throw std::runtime_error("Test assertion failed"); \
} \
} while (0)
#define FAIL() \
do { \
::mscclpp::test::utils::reportFailure(__FILE__, __LINE__, "Test failed"); \
throw std::runtime_error("Test failed"); \
} while (0)
#endif // MSCCLPP_TEST_FRAMEWORK_HPP_