Merge pull request #32 from microsoft/chhwang/log-handler

Add mscclppSetLogHandler
This commit is contained in:
Saeed Maleki
2023-03-27 17:59:00 -07:00
committed by GitHub
5 changed files with 64 additions and 5 deletions

View File

@@ -17,11 +17,17 @@ thread_local int mscclppDebugNoWarn = 0;
char mscclppLastError[1024] = ""; // Global string for the last error in human readable form
uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT
FILE* mscclppDebugFile = stdout;
mscclppLogHandler_t mscclppDebugLogHandler = NULL;
pthread_mutex_t mscclppDebugLock = PTHREAD_MUTEX_INITIALIZER;
std::chrono::steady_clock::time_point mscclppEpoch;
static __thread int tid = -1;
void mscclppDebugDefaultLogHandler(const char* msg)
{
fwrite(msg, 1, strlen(msg), mscclppDebugFile);
}
void mscclppDebugInit()
{
pthread_mutex_lock(&mscclppDebugLock);
@@ -139,6 +145,9 @@ void mscclppDebugInit()
}
}
if (mscclppDebugLogHandler == NULL)
mscclppDebugLogHandler = mscclppDefaultLogHandler;
mscclppEpoch = std::chrono::steady_clock::now();
__atomic_store_n(&mscclppDebugLevel, tempNcclDebugLevel, __ATOMIC_RELEASE);
pthread_mutex_unlock(&mscclppDebugLock);
@@ -157,7 +166,6 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
level = MSCCLPP_LOG_INFO;
flags = mscclppDebugNoWarn;
}
// Save the last error (WARN) as a human readable string
if (level == MSCCLPP_LOG_WARN) {
pthread_mutex_lock(&mscclppDebugLock);
@@ -182,7 +190,7 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
char buffer[1024];
size_t len = 0;
if (level == MSCCLPP_LOG_WARN) {
len = snprintf(buffer, sizeof(buffer), "\n%s:%d:%d [%d] %s:%d MSCCLPP WARN ", hostname, pid, tid, cudaDev, filefunc,
len = snprintf(buffer, sizeof(buffer), "%s:%d:%d [%d] %s:%d MSCCLPP WARN ", hostname, pid, tid, cudaDev, filefunc,
line);
} else if (level == MSCCLPP_LOG_INFO) {
len = snprintf(buffer, sizeof(buffer), "%s:%d:%d [%d] MSCCLPP INFO ", hostname, pid, tid, cudaDev);
@@ -201,10 +209,22 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
va_end(vargs);
buffer[len++] = '\n';
fwrite(buffer, 1, len, mscclppDebugFile);
mscclppDebugLogHandler(buffer);
}
}
mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler)
{
if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1)
mscclppDebugInit();
if (handler == NULL)
return mscclppInvalidArgument;
pthread_mutex_lock(&mscclppDebugLock);
mscclppDebugLogHandler = handler;
pthread_mutex_unlock(&mscclppDebugLock);
return mscclppSuccess;
}
MSCCLPP_PARAM(SetThreadName, "SET_THREAD_NAME", 0);
void mscclppSetThreadName(pthread_t thread, const char* fmt, ...)

View File

@@ -49,8 +49,10 @@ extern pthread_mutex_t mscclppDebugLock;
extern FILE* mscclppDebugFile;
extern mscclppResult_t getHostName(char* hostname, int maxlen, const char delim);
void mscclppDebugDefaultLogHandler(const char* msg);
void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char* filefunc, int line, const char* fmt,
...) __attribute__((format(printf, 5, 6)));
mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler);
// Let code temporarily downgrade WARN into INFO
extern thread_local int mscclppDebugNoWarn;

View File

@@ -313,6 +313,26 @@ mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank);
*/
mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size);
/* Log handler type which is a callback function for
* however user likes to handle the log messages. Once set,
* the logger will just call this function with msg.
*/
typedef void (*mscclppLogHandler_t)(const char* msg);
/* The default log handler.
*
* Inputs:
* msg: the log message
*/
void mscclppDefaultLogHandler(const char* msg);
/* Set a custom log handler.
*
* Inputs:
* handler: the log handler function
*/
mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler);
#ifdef __cplusplus
} // end extern "C"
#endif

View File

@@ -594,4 +594,16 @@ mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
}
*size = comm->nRanks;
return mscclppSuccess;
}
}
MSCCLPP_API(void, mscclppDefaultLogHandler, const char* msg);
void mscclppDefaultLogHandler(const char* msg)
{
mscclppDebugDefaultLogHandler(msg);
}
MSCCLPP_API(mscclppResult_t, mscclppSetLogHandler, mscclppLogHandler_t handler);
mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler)
{
return mscclppDebugSetLogHandler(handler);
}

View File

@@ -34,6 +34,11 @@ void print_usage(const char* prog)
#endif
}
void myLogHandler(const char* msg)
{
printf("myLogger: %s", msg);
}
int main(int argc, const char* argv[])
{
if (argc >= 2 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
@@ -64,8 +69,8 @@ int main(int argc, const char* argv[])
world_size = atoi(argv[3]);
#endif
MSCCLPPCHECK(mscclppSetLogHandler(myLogHandler));
mscclppComm_t comm;
if (ip_port) {
MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, ip_port, rank));
} else {