NUMA binding

This commit is contained in:
Changho Hwang
2023-02-23 08:18:12 +00:00
parent 1a528a3aa3
commit 29a430e7a8
5 changed files with 100 additions and 9 deletions

View File

@@ -101,7 +101,7 @@ LIBDIR := lib
OBJDIR := obj
BINDIR := bin
LDFLAGS := $(NVLDFLAGS) -libverbs -lgdrapi
LDFLAGS := $(NVLDFLAGS) -libverbs -lgdrapi -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc param.cc gdr.cc init.cc proxy.cc ib.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)

View File

@@ -10,6 +10,48 @@
#include "comm.h"
#include "ib.h"
static int getIbDevNumaNode(const char *ibDevPath)
{
if (ibDevPath == NULL) {
WARN("ibDevPath is NULL");
return -1;
}
const char *postfix = "/device/numa_node";
FILE *fp = NULL;
char *filePath = NULL;
int node = -1;
int res;
if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) {
WARN("mscclppCalloc failed");
goto exit;
}
memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char));
filePath[strlen(ibDevPath)] = '\0';
if (strncat(filePath, postfix, strlen(postfix)) == NULL) {
WARN("strncat failed");
goto exit;
}
fp = fopen(filePath, "r");
if (fp == NULL) {
WARN("fopen failed (errno %d, path %s)", errno, filePath);
goto exit;
}
res = fscanf(fp, "%d", &node);
if (res != 1) {
WARN("fscanf failed (errno %d, path %s)", errno, filePath);
node = -1;
goto exit;
}
exit:
if (filePath != NULL) {
free(filePath);
}
if (fp != NULL) {
fclose(fp);
}
return node;
}
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext **ctx, const char *ibDevName)
{
struct mscclppIbContext *_ctx;
@@ -18,10 +60,12 @@ mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext **ctx, const char
std::vector<int> ports;
int num;
const char *ibDevPath = NULL;
struct ibv_device **devices = ibv_get_device_list(&num);
for (int i = 0; i < num; ++i) {
if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) {
_ctx->ctx = ibv_open_device(devices[i]);
ibDevPath = devices[i]->ibdev_path;
break;
}
}
@@ -31,6 +75,11 @@ mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext **ctx, const char
goto fail;
}
_ctx->numaNode = getIbDevNumaNode(ibDevPath);
if (_ctx->numaNode < 0) {
goto fail;
}
// Check available ports
struct ibv_device_attr devAttr;
if (ibv_query_device(_ctx->ctx, &devAttr) != 0) {

View File

@@ -56,7 +56,7 @@ struct mscclppIbQp {
// Holds resources of a single IB device.
struct mscclppIbContext {
int numa_node;
int numaNode;
struct ibv_context *ctx;
struct ibv_pd *pd;
int *ports;

View File

@@ -6,11 +6,20 @@
#include "checks.h"
#include <sys/syscall.h>
#include <numa.h>
#include <map>
#include <thread>
#define MSCCLPP_PROXY_FLAG_SET_BY_RDMA 0
static void NumaBind(int node)
{
nodemask_t mask;
nodemask_zero(&mask);
nodemask_set_compat(&mask, node);
numa_bind_compat(&mask);
}
struct proxyArgs {
struct mscclppComm* comm;
struct mscclppIbContext* ibCtx;
@@ -21,6 +30,7 @@ struct proxyArgs {
void* mscclppProxyService(void* _args) {
struct proxyArgs *args = (struct proxyArgs *)_args;
struct mscclppComm *comm = args->comm;
struct mscclppIbContext *ibCtx = args->ibCtx;
volatile int *run = args->run;
struct mscclppConn *conn = &comm->conns[args->connIdx];
free(_args);
@@ -28,7 +38,8 @@ void* mscclppProxyService(void* _args) {
int currentRemoteFlagVlaue = *conn->cpuRemoteFlag;
#endif
// TODO(chhwang): NUMA & core binding
// TODO(chhwang): core binding
NumaBind(ibCtx->numaNode);
enum {
SEND_STATE_INIT,

View File

@@ -46,12 +46,15 @@ __global__ void kernel(int rank, int world_size)
if (threadIdx.x == 0) {
// Set my data and flag
*(data + rank) = rank + 1;
// Do we need a sys fence?
__threadfence_system();
*localFlag = baseFlag + 1;
}
__syncthreads();
if (threadIdx.x == 0) {
// Do we need a sys fence?
// __threadfence_system();
*localFlag = baseFlag + 1;
}
// Each warp receives data from different ranks
if (devConn.remoteBuff == NULL) { // IB
// Trigger sending data and flag
@@ -86,6 +89,32 @@ int rankToNode(int rank)
return rank / RANKS_PER_NODE;
}
int cudaNumToIbNum(int cudaNum)
{
int ibNum;
if (cudaNum == 0) {
ibNum = 0;
} else if (cudaNum == 1) {
ibNum = 4;
} else if (cudaNum == 2) {
ibNum = 1;
} else if (cudaNum == 3) {
ibNum = 5;
} else if (cudaNum == 4) {
ibNum = 2;
} else if (cudaNum == 5) {
ibNum = 6;
} else if (cudaNum == 6) {
ibNum = 3;
} else if (cudaNum == 7) {
ibNum = 7;
} else {
printf("Invalid cudaNum: %d\n", cudaNum);
exit(EXIT_FAILURE);
}
return ibNum;
}
void print_usage(const char *prog)
{
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
@@ -124,7 +153,11 @@ int main(int argc, const char *argv[])
#endif
int localRank = rankToLocalRank(rank);
int thisNode = rankToNode(rank);
CUDACHECK(cudaSetDevice(localRank));
int cudaNum = localRank;
int ibNum = cudaNumToIbNum(cudaNum);
CUDACHECK(cudaSetDevice(cudaNum));
std::string ibDevStr = "mlx5_ib" + std::to_string(ibNum);
mscclppComm_t comm;
MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port));
@@ -137,8 +170,6 @@ int main(int argc, const char *argv[])
CUDACHECK(cudaMemset(data_d, 0, data_size));
CUDACHECK(cudaMemset(flag_d, 0, sizeof(int)));
std::string ibDevStr = "mlx5_ib" + std::to_string(localRank);
mscclppDevConn_t devConns[16];
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;