mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Reorganize current native algorithm implementation and DSL algorithm implementation. Provide unified API for DSL algo and native algo and provide interface to tune the algo Provide interface for pytorch integration with native API and DSL --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
56 lines
1.7 KiB
C++
56 lines
1.7 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT license.
|
|
|
|
#include <numa.h>
|
|
|
|
#include <fstream>
|
|
#include <mscclpp/gpu_utils.hpp>
|
|
|
|
#include "api.h"
|
|
|
|
// Convert a logical deviceId index to the NVML device minor number
|
|
static const std::string getBusId(int deviceId) {
|
|
// On most systems, the PCI bus ID comes back as in the 0000:00:00.0
|
|
// format. Still need to allocate proper space in case PCI domain goes
|
|
// higher.
|
|
char busIdChar[] = "00000000:00:00.0";
|
|
MSCCLPP_CUDATHROW(cudaDeviceGetPCIBusId(busIdChar, sizeof(busIdChar), deviceId));
|
|
// we need the hex in lower case format
|
|
for (size_t i = 0; i < sizeof(busIdChar); i++) {
|
|
busIdChar[i] = std::tolower(busIdChar[i]);
|
|
}
|
|
return std::string(busIdChar);
|
|
}
|
|
|
|
namespace mscclpp {
|
|
|
|
MSCCLPP_API_CPP int getDeviceNumaNode(int deviceId) {
|
|
std::string busId = getBusId(deviceId);
|
|
std::string file_str = "/sys/bus/pci/devices/" + busId + "/numa_node";
|
|
std::ifstream file(file_str);
|
|
int numaNode;
|
|
if (file.is_open()) {
|
|
if (!(file >> numaNode)) {
|
|
throw Error("Failed to read NUMA node from file: " + file_str, ErrorCode::SystemError);
|
|
}
|
|
} else {
|
|
throw Error("Failed to open file: " + file_str, ErrorCode::SystemError);
|
|
}
|
|
return numaNode;
|
|
}
|
|
|
|
MSCCLPP_API_CPP void numaBind(int node) {
|
|
int totalNumNumaNodes = numa_num_configured_nodes();
|
|
if (node < 0 || node >= totalNumNumaNodes) {
|
|
throw Error(
|
|
"Invalid NUMA node " + std::to_string(node) + ", must be between 0 and " + std::to_string(totalNumNumaNodes),
|
|
ErrorCode::InvalidUsage);
|
|
}
|
|
nodemask_t mask;
|
|
nodemask_zero(&mask);
|
|
nodemask_set_compat(&mask, node);
|
|
numa_bind_compat(&mask);
|
|
}
|
|
|
|
} // namespace mscclpp
|