diff --git a/src/debug.cc b/src/debug.cc
index 86e91409..59190680 100644
--- a/src/debug.cc
+++ b/src/debug.cc
@@ -17,11 +17,20 @@ thread_local int mscclppDebugNoWarn = 0;
 char mscclppLastError[1024] = ""; // Global string for the last error in human readable form
 uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT
 FILE* mscclppDebugFile = stdout;
+mscclppLogHandler_t mscclppDebugLogHandler = NULL; // Callback that receives each formatted log message
 pthread_mutex_t mscclppDebugLock = PTHREAD_MUTEX_INITIALIZER;
 std::chrono::steady_clock::time_point mscclppEpoch;
 
 static __thread int tid = -1;
 
+// Default log handler: writes the NUL-terminated string msg to mscclppDebugFile. A NULL msg is ignored.
+void mscclppDebugDefaultLogHandler(const char* msg)
+{
+  if (msg == NULL)
+    return;
+  fwrite(msg, 1, strlen(msg), mscclppDebugFile);
+}
+
 void mscclppDebugInit()
 {
   pthread_mutex_lock(&mscclppDebugLock);
@@ -139,6 +148,10 @@ void mscclppDebugInit()
     }
   }
 
+  // Install the default handler if no handler has been set yet.
+  if (mscclppDebugLogHandler == NULL)
+    mscclppDebugLogHandler = mscclppDebugDefaultLogHandler;
+
   mscclppEpoch = std::chrono::steady_clock::now();
   __atomic_store_n(&mscclppDebugLevel, tempNcclDebugLevel, __ATOMIC_RELEASE);
   pthread_mutex_unlock(&mscclppDebugLock);
@@ -157,7 +170,6 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
     level = MSCCLPP_LOG_INFO;
     flags = mscclppDebugNoWarn;
   }
-
   // Save the last error (WARN) as a human readable string
   if (level == MSCCLPP_LOG_WARN) {
     pthread_mutex_lock(&mscclppDebugLock);
@@ -182,7 +194,7 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
   char buffer[1024];
   size_t len = 0;
   if (level == MSCCLPP_LOG_WARN) {
-    len = snprintf(buffer, sizeof(buffer), "\n%s:%d:%d [%d] %s:%d MSCCLPP WARN ", hostname, pid, tid, cudaDev, filefunc,
+    len = snprintf(buffer, sizeof(buffer), "%s:%d:%d [%d] %s:%d MSCCLPP WARN ", hostname, pid, tid, cudaDev, filefunc,
                    line);
   } else if (level == MSCCLPP_LOG_INFO) {
     len = snprintf(buffer, sizeof(buffer), "%s:%d:%d [%d] MSCCLPP INFO ", hostname, pid, tid, cudaDev);
@@ -201,10 +213,29 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
     len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
     va_end(vargs);
-    buffer[len++] = '\n';
-    fwrite(buffer, 1, len, mscclppDebugFile);
+    // snprintf/vsnprintf return the untruncated length, so clamp len to leave room for '\n' and '\0'.
+    if (len > sizeof(buffer) - 2)
+      len = sizeof(buffer) - 2;
+    buffer[len++] = '\n';
+    // The handler receives a C string without a length, so the message must be NUL-terminated.
+    buffer[len] = '\0';
+    mscclppDebugLogHandler(buffer);
   }
 }
 
+// Set the callback that receives each formatted log message.
+// Returns mscclppInvalidArgument if handler is NULL, mscclppSuccess otherwise.
+mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler)
+{
+  if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1)
+    mscclppDebugInit();
+  if (handler == NULL)
+    return mscclppInvalidArgument;
+  pthread_mutex_lock(&mscclppDebugLock);
+  mscclppDebugLogHandler = handler;
+  pthread_mutex_unlock(&mscclppDebugLock);
+  return mscclppSuccess;
+}
+
 MSCCLPP_PARAM(SetThreadName, "SET_THREAD_NAME", 0);
 
 void mscclppSetThreadName(pthread_t thread, const char* fmt, ...)
diff --git a/src/include/debug.h b/src/include/debug.h
index 95da5bd5..dd548cbb 100644
--- a/src/include/debug.h
+++ b/src/include/debug.h
@@ -49,8 +49,10 @@ extern pthread_mutex_t mscclppDebugLock;
 extern FILE* mscclppDebugFile;
 extern mscclppResult_t getHostName(char* hostname, int maxlen, const char delim);
 
+void mscclppDebugDefaultLogHandler(const char* msg);
 void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char* filefunc, int line, const char* fmt, ...)
   __attribute__((format(printf, 5, 6)));
+mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler);
 
 // Let code temporarily downgrade WARN into INFO
 extern thread_local int mscclppDebugNoWarn;
diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h
index 831bab75..24fb4063 100644
--- a/src/include/mscclpp.h
+++ b/src/include/mscclpp.h
@@ -313,6 +313,28 @@ mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank);
  */
 mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size);
 
+/* Log handler type: a callback through which the user handles log
+ * messages in whatever way they like. Once set, the logger calls this
+ * function with each formatted, NUL-terminated message.
+ */
+typedef void (*mscclppLogHandler_t)(const char* msg);
+
+/* The default log handler.
+ *
+ * Inputs:
+ *   msg: the log message (a NUL-terminated string)
+ */
+void mscclppDefaultLogHandler(const char* msg);
+
+/* Set a custom log handler.
+ *
+ * Inputs:
+ *   handler: the log handler function; must not be NULL
+ *
+ * Returns mscclppInvalidArgument if handler is NULL.
+ */
+mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler);
+
 #ifdef __cplusplus
 } // end extern "C"
 #endif
diff --git a/src/init.cc b/src/init.cc
index 09015b9d..3845ce04 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -594,4 +594,16 @@ mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
   }
   *size = comm->nRanks;
   return mscclppSuccess;
-}
\ No newline at end of file
+}
+
+MSCCLPP_API(void, mscclppDefaultLogHandler, const char* msg);
+void mscclppDefaultLogHandler(const char* msg)
+{
+  mscclppDebugDefaultLogHandler(msg);
+}
+
+MSCCLPP_API(mscclppResult_t, mscclppSetLogHandler, mscclppLogHandler_t handler);
+mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler)
+{
+  return mscclppDebugSetLogHandler(handler);
+}
diff --git a/tests/bootstrap_test.cc b/tests/bootstrap_test.cc
index b9d5ea84..0715d24f 100644
--- a/tests/bootstrap_test.cc
+++ b/tests/bootstrap_test.cc
@@ -34,6 +34,11 @@ void print_usage(const char* prog)
 #endif
 }
 
+void myLogHandler(const char* msg)
+{
+  printf("myLogger: %s", msg);
+}
+
 int main(int argc, const char* argv[])
 {
   if (argc >= 2 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
@@ -64,8 +69,8 @@ int main(int argc, const char* argv[])
   world_size = atoi(argv[3]);
 #endif
 
+  MSCCLPPCHECK(mscclppSetLogHandler(myLogHandler));
   mscclppComm_t comm;
-
   if (ip_port) {
     MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, ip_port, rank));
   } else {