mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Merge apt-extension
This commit is contained in:
139
.clang-format
139
.clang-format
@@ -1,137 +1,2 @@
|
||||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: Microsoft
|
||||
AccessModifierOffset: -2
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveMacros: false
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlines: Right
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllArgumentsOnNextLine: true
|
||||
AllowAllConstructorInitializersOnNextLine: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: Never
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: None
|
||||
AllowShortLambdasOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: Never
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: false
|
||||
AlwaysBreakTemplateDeclarations: MultiLine
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterCaseLabel: false
|
||||
AfterClass: true
|
||||
AfterControlStatement: false # true
|
||||
AfterEnum: true
|
||||
AfterFunction: true
|
||||
AfterNamespace: false # true
|
||||
AfterObjCDeclaration: true
|
||||
AfterStruct: true
|
||||
AfterUnion: false
|
||||
AfterExternBlock: true
|
||||
BeforeCatch: false # true
|
||||
BeforeElse: false # true
|
||||
IndentBraces: false
|
||||
SplitEmptyFunction: true
|
||||
SplitEmptyRecord: true
|
||||
SplitEmptyNamespace: true
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Custom
|
||||
BreakBeforeInheritanceComma: false
|
||||
BreakInheritanceList: BeforeColon
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
BreakConstructorInitializers: BeforeColon
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakStringLiterals: true
|
||||
ColumnLimit: 120
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
CompactNamespaces: false
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
||||
ConstructorInitializerIndentWidth: 2
|
||||
ContinuationIndentWidth: 2
|
||||
Cpp11BracedListStyle: true
|
||||
DeriveLineEnding: true
|
||||
DerivePointerAlignment: false
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
FixNamespaceComments: true
|
||||
ForEachMacros:
|
||||
- foreach
|
||||
- Q_FOREACH
|
||||
- BOOST_FOREACH
|
||||
IncludeBlocks: Preserve
|
||||
IncludeCategories:
|
||||
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
||||
Priority: 2
|
||||
SortPriority: 0
|
||||
- Regex: '^(<|"(gtest|gmock|isl|json)/)'
|
||||
Priority: 3
|
||||
SortPriority: 0
|
||||
- Regex: '.*'
|
||||
Priority: 1
|
||||
SortPriority: 0
|
||||
IncludeIsMainRegex: '(Test)?$'
|
||||
IncludeIsMainSourceRegex: ''
|
||||
IndentCaseLabels: false
|
||||
IndentExternBlock: NoIndent
|
||||
IndentGotoLabels: true
|
||||
IndentPPDirectives: None
|
||||
IndentWidth: 2
|
||||
IndentWrappedFunctionNames: false
|
||||
JavaScriptQuotes: Leave
|
||||
JavaScriptWrapImports: true
|
||||
KeepEmptyLinesAtTheStartOfBlocks: true
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBinPackProtocolList: Auto
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: true
|
||||
PenaltyBreakAssignment: 2
|
||||
PenaltyBreakBeforeFirstCallParameter: 19
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyBreakTemplateDeclaration: 10
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 1000
|
||||
PointerAlignment: Left
|
||||
ReflowComments: true
|
||||
SortIncludes: true
|
||||
SortUsingDeclarations: true
|
||||
SpaceAfterCStyleCast: false
|
||||
SpaceAfterLogicalNot: false
|
||||
SpaceAfterTemplateKeyword: true
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeCpp11BracedList: false
|
||||
SpaceBeforeCtorInitializerColon: true
|
||||
SpaceBeforeInheritanceColon: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceBeforeRangeBasedForLoopColon: true
|
||||
SpaceInEmptyBlock: false
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 1
|
||||
SpacesInAngles: false
|
||||
SpacesInConditionalStatement: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
SpaceBeforeSquareBrackets: false
|
||||
Standard: Latest
|
||||
StatementMacros:
|
||||
- Q_UNUSED
|
||||
- QT_REQUIRE_VERSION
|
||||
TabWidth: 2
|
||||
UseCRLF: false
|
||||
UseTab: Never
|
||||
...
|
||||
BasedOnStyle: Google
|
||||
ColumnLimit: 120
|
||||
|
||||
@@ -1,32 +1,48 @@
|
||||
cmake_minimum_required(VERSION 3.26)
|
||||
|
||||
project(mscclpp LANGUAGES CUDA CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules)
|
||||
option(ENABLE_TRACE "Enable tracing" OFF)
|
||||
option(USE_MPI_FOR_TESTS "Use MPI for tests" ON)
|
||||
option(USE_NPKIT "Use NPKIT" ON)
|
||||
option(ALLOW_GDRCOPY "Use GDRCopy, if available" OFF)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
find_package(IBVerbs REQUIRED)
|
||||
find_package(NUMA REQUIRED)
|
||||
find_package(GDRCopy)
|
||||
|
||||
option(USE_MPI_FOR_TESTS "Use MPI for tests" ON)
|
||||
if(USE_MPI_FOR_TESTS)
|
||||
find_package(MPI REQUIRED)
|
||||
add_definitions(-DMSCCLPP_USE_MPI_FOR_TESTS)
|
||||
endif()
|
||||
if(ALLOW_GDRCOPY)
|
||||
find_package(GDRCopy)
|
||||
endif()
|
||||
|
||||
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
include(CTest)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/b796f7d44681514f58a683a3a71ff17c94edb0c1.zip)
|
||||
FetchContent_MakeAvailable(googletest)
|
||||
include(GoogleTest)
|
||||
|
||||
set(CLANG_FORMAT_SOURCE_DIRS include src tests)
|
||||
include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake)
|
||||
|
||||
add_library(mscclpp SHARED)
|
||||
add_subdirectory(src) # This adds the srouces to the mscclpp target
|
||||
target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src/include)
|
||||
target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src/include)
|
||||
set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver)
|
||||
if(GDRCOPY_FOUND)
|
||||
if(ENABLE_TRACE)
|
||||
target_compile_definitions(mscclpp PRIVATE ENABLE_TRACE)
|
||||
endif()
|
||||
if(USE_NPKIT)
|
||||
target_compile_definitions(mscclpp PRIVATE ENABLE_NPKIT)
|
||||
endif()
|
||||
if(ALLOW_GDRCOPY AND GDRCOPY_FOUND)
|
||||
target_compile_definitions(mscclpp PRIVATE MSCCLPP_USE_GDRCOPY)
|
||||
target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy)
|
||||
endif()
|
||||
|
||||
add_subdirectory(tests)
|
||||
add_subdirectory(src) # This adds the sources to the mscclpp target
|
||||
add_subdirectory(test)
|
||||
|
||||
238
Makefile
238
Makefile
@@ -1,238 +0,0 @@
|
||||
######## VERSION
|
||||
MSCCLPP_MAJOR := 0
|
||||
MSCCLPP_MINOR := 1
|
||||
MSCCLPP_PATCH := 0
|
||||
|
||||
######## COMPILE OPTIONS
|
||||
DEBUG ?= 0
|
||||
VERBOSE ?= 1
|
||||
TRACE ?= 0
|
||||
NPKIT ?= 0
|
||||
GDRCOPY ?= 0
|
||||
USE_MPI_FOR_TESTS ?= 1
|
||||
|
||||
######## CUDA
|
||||
CUDA_HOME ?= /usr/local/cuda
|
||||
CUDA_LIB ?= $(CUDA_HOME)/lib64
|
||||
CUDA_INC ?= $(CUDA_HOME)/include
|
||||
NVCC = $(CUDA_HOME)/bin/nvcc
|
||||
CUDA_VERSION = $(strip $(shell which $(NVCC) >/dev/null && $(NVCC) --version | grep release | sed 's/.*release //' | sed 's/\,.*//'))
|
||||
CUDA_MAJOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 1)
|
||||
CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2)
|
||||
# You should define NVCC_GENCODE in your environment to the minimal set
|
||||
# of archs to reduce compile time.
|
||||
CUDA8_GENCODE = -gencode=arch=compute_50,code=sm_50 \
|
||||
-gencode=arch=compute_60,code=sm_60 \
|
||||
-gencode=arch=compute_61,code=sm_61
|
||||
CUDA9_GENCODE = -gencode=arch=compute_70,code=sm_70
|
||||
CUDA11_GENCODE = -gencode=arch=compute_80,code=sm_80
|
||||
CUDA12_GENCODE = -gencode=arch=compute_90,code=sm_90
|
||||
|
||||
CUDA8_PTX = -gencode=arch=compute_61,code=compute_61
|
||||
CUDA9_PTX = -gencode=arch=compute_70,code=compute_70
|
||||
CUDA11_PTX = -gencode=arch=compute_80,code=compute_80
|
||||
CUDA12_PTX = -gencode=arch=compute_90,code=compute_90
|
||||
|
||||
######## CXX/NVCC
|
||||
CXX := g++
|
||||
NVTX ?= 1
|
||||
|
||||
ifeq ($(shell test "0$(CUDA_MAJOR)" -eq 11 -a "0$(CUDA_MINOR)" -ge 8 -o "0$(CUDA_MAJOR)" -gt 11; echo $$?),0)
|
||||
# Include Hopper support if we're using CUDA11.8 or above
|
||||
NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA12_GENCODE) $(CUDA12_PTX)
|
||||
else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 11; echo $$?),0)
|
||||
NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_PTX)
|
||||
# Include Volta support if we're using CUDA9 or above
|
||||
else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 9; echo $$?),0)
|
||||
NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA9_PTX)
|
||||
else
|
||||
NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA8_PTX)
|
||||
endif
|
||||
$(info NVCC_GENCODE is ${NVCC_GENCODE})
|
||||
|
||||
CXXFLAGS := -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR) -fPIC -fvisibility=hidden \
|
||||
-Wall -Wno-unused-function -Wno-sign-compare -std=c++14 -Wvla \
|
||||
-I $(CUDA_INC) \
|
||||
$(CXXFLAGS)
|
||||
|
||||
ifneq ($(TRACE), 0)
|
||||
CXXFLAGS += -DENABLE_TRACE
|
||||
endif
|
||||
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all
|
||||
# Use addprefix so that we can specify more than one path
|
||||
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt -lcuda
|
||||
|
||||
ifeq ($(DEBUG), 0)
|
||||
NVCUFLAGS += -O3
|
||||
CXXFLAGS += -O3 -g
|
||||
else
|
||||
NVCUFLAGS += -O0 -G -g
|
||||
CXXFLAGS += -O0 -g -ggdb3
|
||||
endif
|
||||
|
||||
ifneq ($(VERBOSE), 0)
|
||||
NVCUFLAGS += -Xptxas -v -Xcompiler -Wall,-Wextra,-Wno-unused-parameter
|
||||
CXXFLAGS += -Wall -Wextra
|
||||
else
|
||||
.SILENT:
|
||||
endif
|
||||
|
||||
ifeq ($(NVTX), 0)
|
||||
CXXFLAGS += -DNVTX_DISABLE
|
||||
endif
|
||||
|
||||
#### MPI (only for test code)
|
||||
ifeq ($(USE_MPI_FOR_TESTS), 1)
|
||||
MPI_HOME ?= /usr/local/mpi
|
||||
MPI_INC := -I$(MPI_HOME)/include
|
||||
MPI_LDFLAGS := -L$(MPI_HOME)/lib -lmpi
|
||||
MPI_MACRO := -D MSCCLPP_USE_MPI_FOR_TESTS
|
||||
else
|
||||
MPI_HOME :=
|
||||
MPI_INC :=
|
||||
MPI_LDFLAGS :=
|
||||
MPI_MACRO :=
|
||||
endif
|
||||
|
||||
#### GDRCOPY
|
||||
ifeq ($(GDRCOPY), 1)
|
||||
GDRCOPY_LDFLAGS := -lgdrapi
|
||||
CXXFLAGS += -DMSCCLPP_USE_GDRCOPY
|
||||
NVCUFLAGS += -DMSCCLPP_USE_GDRCOPY
|
||||
else
|
||||
GDRCOPY_LDFLAGS :=
|
||||
endif
|
||||
|
||||
#### MSCCL++
|
||||
BUILDDIR ?= $(abspath ./build)
|
||||
INCDIR := include
|
||||
LIBDIR := lib
|
||||
OBJDIR := obj
|
||||
BINDIR := bin
|
||||
|
||||
ifneq ($(NPKIT), 0)
|
||||
CXXFLAGS += -DENABLE_NPKIT
|
||||
NVCUFLAGS += -DENABLE_NPKIT
|
||||
endif
|
||||
|
||||
LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
|
||||
|
||||
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
|
||||
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
|
||||
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc errors.cc)
|
||||
ifneq ($(NPKIT), 0)
|
||||
LIBSRCS += $(addprefix src/misc/,npkit.cc)
|
||||
endif
|
||||
ifeq ($(GDRCOPY), 1)
|
||||
LIBSRCS += $(addprefix src/,gdr.cc)
|
||||
endif
|
||||
LIBOBJS := $(patsubst %.cc,%.o,$(LIBSRCS))
|
||||
LIBOBJTARGETS := $(LIBOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
|
||||
HEADERS := $(wildcard src/include/*.h)
|
||||
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
|
||||
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
|
||||
|
||||
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp errors.hpp
|
||||
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
|
||||
|
||||
LIBNAME := libmscclpp.so
|
||||
LIBSONAME := $(LIBNAME).$(MSCCLPP_MAJOR)
|
||||
LIBTARGET := $(BUILDDIR)/$(LIBDIR)/$(LIBNAME).$(MSCCLPP_MAJOR).$(MSCCLPP_MINOR).$(MSCCLPP_PATCH)
|
||||
|
||||
UTDIR := tests/unittests
|
||||
UTSRCS := $(addprefix $(UTDIR)/,ib_test.cc)
|
||||
UTOBJS := $(patsubst %.cc,%.o,$(UTSRCS))
|
||||
UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu communicator_test_cpp.cu bootstrap_test_cpp.cc allgather_test_cpp.cu)
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
MSCLLPPTESTSOBJSDIR:= $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)
|
||||
MSCLLPPTESTBINFILESLIST := allgather_test allreduce_test sendrecv_test
|
||||
MSCLLPPTESTBINS := $(MSCLLPPTESTBINFILESLIST:%=$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf)
|
||||
|
||||
INCLUDE := -Isrc -Isrc/include
|
||||
|
||||
.PHONY: all build lib unittests tests mscclpp-test cpplint cpplint-autofix cpplint-file-autofix clean
|
||||
|
||||
all: build
|
||||
|
||||
build: lib tests
|
||||
ifeq ($(USE_MPI_FOR_TESTS), 1)
|
||||
build: lib tests mscclpp-test
|
||||
endif
|
||||
|
||||
lib: $(LIBOBJTARGETS) $(INCTARGETS) $(LIBTARGET)
|
||||
|
||||
unittests: $(UTBINS)
|
||||
|
||||
tests: unittests $(TESTSBINS)
|
||||
|
||||
mscclpp-test: $(LIBTARGET) $(MSCLLPPTESTBINS)
|
||||
|
||||
cpplint:
|
||||
clang-format-12 -style=file --verbose --Werror --dry-run $(CPPSOURCES)
|
||||
clang-format-12 --dry-run $(PYTHONCPPSOURCES)
|
||||
|
||||
cpplint-autofix:
|
||||
clang-format-12 -style=file --verbose --Werror -i $(CPPSOURCES)
|
||||
clang-format-12 -i $(PYTHONCPPSOURCES)
|
||||
|
||||
# Run cpplint on a single file, example: make cpplint-file-autofix INPUTFILE=src/bootstrap/bootstrap.cc
|
||||
cpplint-file-autofix:
|
||||
clang-format-12 -style=file --verbose --Werror -i $(INPUTFILE)
|
||||
|
||||
# Compile libobjs
|
||||
$(BUILDDIR)/$(OBJDIR)/%.o: %.cc $(HEADERS)
|
||||
@mkdir -p $(@D)
|
||||
$(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $<
|
||||
|
||||
# Compile utobjs
|
||||
$(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o: $(UTDIR)/%.cc $(HEADERS)
|
||||
@mkdir -p $(@D)
|
||||
$(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $<
|
||||
|
||||
$(BUILDDIR)/$(INCDIR)/%: src/$(INCDIR)/%
|
||||
@mkdir -p $(@D)
|
||||
cp $< $@
|
||||
|
||||
$(LIBTARGET): $(LIBOBJTARGETS)
|
||||
@mkdir -p $(@D)
|
||||
$(CXX) -shared -Wl,--no-as-needed -Wl,-soname,$(LIBSONAME) -o $@ $^ $(CXXFLAGS) $(LDFLAGS)
|
||||
ln -sf $(LIBTARGET) $(BUILDDIR)/$(LIBDIR)/$(LIBNAME)
|
||||
ln -sf $(LIBTARGET) $(BUILDDIR)/$(LIBDIR)/$(LIBSONAME)
|
||||
|
||||
# UT bins
|
||||
$(BUILDDIR)/$(BINDIR)/$(UTDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o $(LIBOBJTARGETS)
|
||||
@mkdir -p $(@D)
|
||||
$(NVCC) -o $@ $+ $(MPI_LDFLAGS) $(LDFLAGS)
|
||||
|
||||
# Compile .cc tests
|
||||
$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cc $(INCTARGETS)
|
||||
@mkdir -p $(@D)
|
||||
$(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -c $< $(MPI_MACRO)
|
||||
|
||||
# Compile .cu tests
|
||||
$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cu $(INCTARGETS)
|
||||
@mkdir -p $(@D)
|
||||
$(NVCC) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(NVCUFLAGS) $(INCLUDE) -c $< $(MPI_MACRO)
|
||||
|
||||
# Test bins
|
||||
$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o $(LIBTARGET)
|
||||
@mkdir -p $(@D)
|
||||
$(NVCC) -o $@ $< $(MPI_LDFLAGS) -L$(BUILDDIR)/$(LIBDIR) -lmscclpp
|
||||
|
||||
# Compile mscclpp_test
|
||||
$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf: $(MSCLLPPTESTSOBJSDIR)/%.o $(MSCLLPPTESTSOBJSDIR)/common.o
|
||||
@mkdir -p $(@D)
|
||||
$(NVCC) -o $@ $^ $(MPI_LDFLAGS) -L$(BUILDDIR)/$(LIBDIR) -lmscclpp
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILDDIR)
|
||||
18
cmake/AddClangFormatTargets.cmake
Normal file
18
cmake/AddClangFormatTargets.cmake
Normal file
@@ -0,0 +1,18 @@
|
||||
# Add targets to run clang-format
|
||||
|
||||
find_program(CLANG_FORMAT clang-format)
|
||||
if(CLANG_FORMAT)
|
||||
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
|
||||
set(CLANG_FORMAT_FILE_TYPES *.h *.hpp *.c *.cc *.cpp *.cu)
|
||||
# Produce combinations of source directories and file types
|
||||
foreach(SOURCE_DIR ${CLANG_FORMAT_SOURCE_DIRS})
|
||||
foreach(FILE_TYPE ${CLANG_FORMAT_FILE_TYPES})
|
||||
list(APPEND CLANG_FORMAT_SOURCE_PATTERNS ${SOURCE_DIR}/${FILE_TYPE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
file(GLOB_RECURSE CLANG_FORMAT_SOURCES ${CLANG_FORMAT_SOURCE_PATTERNS})
|
||||
add_custom_target(check-format ALL COMMAND ${CLANG_FORMAT} -style=file --dry-run ${CLANG_FORMAT_SOURCES})
|
||||
add_custom_target(format COMMAND ${CLANG_FORMAT} -style=file -i ${CLANG_FORMAT_SOURCES})
|
||||
else()
|
||||
message(STATUS "clang-format not found.")
|
||||
endif()
|
||||
@@ -1,34 +1,26 @@
|
||||
#ifndef MSCCLPP_CHANNEL_HPP_
|
||||
#define MSCCLPP_CHANNEL_HPP_
|
||||
|
||||
#include "epoch.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "utils.hpp"
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/epoch.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
namespace channel {
|
||||
|
||||
// A Channel pairs a Connection with an Epoch
|
||||
class Channel
|
||||
{
|
||||
public:
|
||||
class Channel {
|
||||
public:
|
||||
Channel(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)){};
|
||||
: connection_(connection), epoch_(std::make_shared<DeviceEpoch>(communicator, connection)){};
|
||||
|
||||
Connection& connection()
|
||||
{
|
||||
return *connection_;
|
||||
}
|
||||
Epoch& epoch()
|
||||
{
|
||||
return *epoch_;
|
||||
}
|
||||
Connection& connection() { return *connection_; }
|
||||
DeviceEpoch& epoch() { return *epoch_; }
|
||||
|
||||
private:
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
std::shared_ptr<Epoch> epoch_;
|
||||
std::shared_ptr<DeviceEpoch> epoch_;
|
||||
};
|
||||
|
||||
using ChannelId = uint32_t;
|
||||
@@ -52,12 +44,11 @@ using MemoryId = uint32_t;
|
||||
// the summation of number of bits must be 128 or less
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
struct
|
||||
{
|
||||
struct {
|
||||
// first 64 bits: value[0]
|
||||
uint64_t size : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// second 64 bits: value[1]
|
||||
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
@@ -65,19 +56,14 @@ union ChannelTrigger {
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t chanId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE -
|
||||
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ ChannelTrigger()
|
||||
{
|
||||
}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value)
|
||||
{
|
||||
}
|
||||
__device__ ChannelTrigger() {}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size, int connectionId)
|
||||
{
|
||||
uint64_t size, int connectionId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
|
||||
<< MSCCLPP_BITS_REGMEM_HANDLE) +
|
||||
@@ -85,70 +71,61 @@ union ChannelTrigger {
|
||||
<< MSCCLPP_BITS_OFFSET) +
|
||||
dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct DeviceChannel
|
||||
{
|
||||
struct DeviceChannel {
|
||||
DeviceChannel() = default;
|
||||
|
||||
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo)
|
||||
: channelId_(channelId), epoch_(epoch), fifo_(fifo)
|
||||
{
|
||||
}
|
||||
DeviceChannel(ChannelId channelId, DeviceEpoch::DeviceHandle epoch, DeviceProxyFifo fifo)
|
||||
: channelId_(channelId), epoch_(epoch), fifo_(fifo) {}
|
||||
|
||||
DeviceChannel(const DeviceChannel& other) = default;
|
||||
|
||||
DeviceChannel& operator=(DeviceChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
__forceinline__ __device__ void signal() {
|
||||
epochIncrement();
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
uint64_t size) {
|
||||
epochIncrement();
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
uint64_t srcOffset, uint64_t size) {
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, channelId_).value);
|
||||
__forceinline__ __device__ void flush() {
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, channelId_).value);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
@@ -156,20 +133,14 @@ struct DeviceChannel
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
epoch_.wait();
|
||||
}
|
||||
__forceinline__ __device__ void wait() { epoch_.wait(); }
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
epoch_.epochIncrement();
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
__forceinline__ __device__ void epochIncrement() { epoch_.epochIncrement(); }
|
||||
#endif // __CUDACC__
|
||||
|
||||
ChannelId channelId_;
|
||||
|
||||
DeviceEpoch epoch_;
|
||||
DeviceEpoch::DeviceHandle epoch_;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
@@ -180,42 +151,29 @@ class DeviceChannelService;
|
||||
|
||||
inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService);
|
||||
|
||||
class DeviceChannelService
|
||||
{
|
||||
public:
|
||||
class DeviceChannelService {
|
||||
public:
|
||||
DeviceChannelService(Communicator& communicator);
|
||||
|
||||
ChannelId addChannel(std::shared_ptr<Connection> connection)
|
||||
{
|
||||
ChannelId addChannel(std::shared_ptr<Connection> connection) {
|
||||
channels_.push_back(Channel(communicator_, connection));
|
||||
return channels_.size() - 1;
|
||||
}
|
||||
|
||||
MemoryId addMemory(RegisteredMemory memory)
|
||||
{
|
||||
MemoryId addMemory(RegisteredMemory memory) {
|
||||
memories_.push_back(memory);
|
||||
return memories_.size() - 1;
|
||||
}
|
||||
|
||||
Channel channel(ChannelId id)
|
||||
{
|
||||
return channels_[id];
|
||||
}
|
||||
DeviceChannel deviceChannel(ChannelId id)
|
||||
{
|
||||
return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo());
|
||||
Channel channel(ChannelId id) { return channels_[id]; }
|
||||
DeviceChannel deviceChannel(ChannelId id) {
|
||||
return DeviceChannel(id, channels_[id].epoch().deviceHandle(), proxy_.fifo().deviceFifo());
|
||||
}
|
||||
|
||||
void startProxy()
|
||||
{
|
||||
proxy_.start();
|
||||
}
|
||||
void stopProxy()
|
||||
{
|
||||
proxy_.stop();
|
||||
}
|
||||
void startProxy() { proxy_.start(); }
|
||||
void stopProxy() { proxy_.stop(); }
|
||||
|
||||
private:
|
||||
private:
|
||||
Communicator& communicator_;
|
||||
std::vector<Channel> channels_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
@@ -224,8 +182,7 @@ private:
|
||||
|
||||
void bindThread();
|
||||
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw)
|
||||
{
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
Channel& channel = channels_[trigger->fields.chanId];
|
||||
|
||||
@@ -250,13 +207,10 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
struct SimpleDeviceChannel
|
||||
{
|
||||
struct SimpleDeviceChannel {
|
||||
SimpleDeviceChannel() = default;
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src)
|
||||
{
|
||||
}
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {}
|
||||
|
||||
SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;
|
||||
|
||||
@@ -264,64 +218,42 @@ struct SimpleDeviceChannel
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
devChan_.put(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(offset, offset, size);
|
||||
}
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
devChan_.signal();
|
||||
}
|
||||
__forceinline__ __device__ void signal() { devChan_.signal(); }
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(offset, offset, size);
|
||||
}
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
devChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
devChan_.flush();
|
||||
}
|
||||
__forceinline__ __device__ void flush() { devChan_.flush(); }
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
devChan_.wait();
|
||||
}
|
||||
__forceinline__ __device__ void wait() { devChan_.wait(); }
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
devChan_.epochIncrement();
|
||||
}
|
||||
__forceinline__ __device__ void epochIncrement() { devChan_.epochIncrement(); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
#endif // __CUDACC__
|
||||
|
||||
DeviceChannel devChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
};
|
||||
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CHANNEL_HPP_
|
||||
#endif // MSCCLPP_CHANNEL_HPP_
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef MSCCLPP_HPP_
|
||||
#define MSCCLPP_HPP_
|
||||
#ifndef MSCCLPP_CORE_HPP_
|
||||
#define MSCCLPP_CORE_HPP_
|
||||
|
||||
#define MSCCLPP_MAJOR 0
|
||||
#define MSCCLPP_MINOR 1
|
||||
@@ -9,20 +9,19 @@
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
struct UniqueId
|
||||
{
|
||||
struct UniqueId {
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
};
|
||||
|
||||
class BaseBootstrap
|
||||
{
|
||||
public:
|
||||
class BaseBootstrap {
|
||||
public:
|
||||
BaseBootstrap(){};
|
||||
virtual ~BaseBootstrap() = default;
|
||||
virtual int getRank() = 0;
|
||||
@@ -33,14 +32,12 @@ public:
|
||||
virtual void barrier() = 0;
|
||||
|
||||
// TODO: move implementations of these helpers out of this header
|
||||
void send(const std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
void send(const std::vector<char>& data, int peer, int tag) {
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag + 1);
|
||||
}
|
||||
void recv(std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
void recv(std::vector<char>& data, int peer, int tag) {
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
@@ -48,9 +45,8 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class Bootstrap : public BaseBootstrap
|
||||
{
|
||||
public:
|
||||
class Bootstrap : public BaseBootstrap {
|
||||
public:
|
||||
Bootstrap(int rank, int nRanks);
|
||||
~Bootstrap();
|
||||
|
||||
@@ -66,7 +62,7 @@ public:
|
||||
void allGather(void* allData, int size) override;
|
||||
void barrier() override;
|
||||
|
||||
private:
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
};
|
||||
@@ -80,147 +76,78 @@ private:
|
||||
*/
|
||||
std::unique_ptr<UniqueId> getUniqueId();
|
||||
|
||||
enum class Transport
|
||||
{
|
||||
Unknown,
|
||||
CudaIpc,
|
||||
IB0,
|
||||
IB1,
|
||||
IB2,
|
||||
IB3,
|
||||
IB4,
|
||||
IB5,
|
||||
IB6,
|
||||
IB7,
|
||||
NumTransports
|
||||
};
|
||||
enum class Transport { Unknown, CudaIpc, IB0, IB1, IB2, IB3, IB4, IB5, IB6, IB7, NumTransports };
|
||||
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 10;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
|
||||
"TransportFlagsSize must match the number of transports");
|
||||
using TransportFlagsBase = std::bitset<TransportFlagsSize>;
|
||||
} // namespace detail
|
||||
} // namespace detail
|
||||
|
||||
class TransportFlags : private detail::TransportFlagsBase
|
||||
{
|
||||
public:
|
||||
class TransportFlags : private detail::TransportFlagsBase {
|
||||
public:
|
||||
TransportFlags() = default;
|
||||
TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast<size_t>(transport))
|
||||
{
|
||||
}
|
||||
TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast<size_t>(transport)) {}
|
||||
|
||||
bool has(Transport transport) const
|
||||
{
|
||||
return detail::TransportFlagsBase::test(static_cast<size_t>(transport));
|
||||
}
|
||||
bool has(Transport transport) const { return detail::TransportFlagsBase::test(static_cast<size_t>(transport)); }
|
||||
|
||||
bool none() const
|
||||
{
|
||||
return detail::TransportFlagsBase::none();
|
||||
}
|
||||
bool none() const { return detail::TransportFlagsBase::none(); }
|
||||
|
||||
bool any() const
|
||||
{
|
||||
return detail::TransportFlagsBase::any();
|
||||
}
|
||||
bool any() const { return detail::TransportFlagsBase::any(); }
|
||||
|
||||
bool all() const
|
||||
{
|
||||
return detail::TransportFlagsBase::all();
|
||||
}
|
||||
bool all() const { return detail::TransportFlagsBase::all(); }
|
||||
|
||||
size_t count() const
|
||||
{
|
||||
return detail::TransportFlagsBase::count();
|
||||
}
|
||||
size_t count() const { return detail::TransportFlagsBase::count(); }
|
||||
|
||||
TransportFlags& operator|=(TransportFlags other)
|
||||
{
|
||||
TransportFlags& operator|=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator|=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator|(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) |= other;
|
||||
}
|
||||
TransportFlags operator|(TransportFlags other) const { return TransportFlags(*this) |= other; }
|
||||
|
||||
TransportFlags operator|(Transport transport) const
|
||||
{
|
||||
return *this | TransportFlags(transport);
|
||||
}
|
||||
TransportFlags operator|(Transport transport) const { return *this | TransportFlags(transport); }
|
||||
|
||||
TransportFlags& operator&=(TransportFlags other)
|
||||
{
|
||||
TransportFlags& operator&=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator&=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator&(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) &= other;
|
||||
}
|
||||
TransportFlags operator&(TransportFlags other) const { return TransportFlags(*this) &= other; }
|
||||
|
||||
TransportFlags operator&(Transport transport) const
|
||||
{
|
||||
return *this & TransportFlags(transport);
|
||||
}
|
||||
TransportFlags operator&(Transport transport) const { return *this & TransportFlags(transport); }
|
||||
|
||||
TransportFlags& operator^=(TransportFlags other)
|
||||
{
|
||||
TransportFlags& operator^=(TransportFlags other) {
|
||||
detail::TransportFlagsBase::operator^=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator^(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) ^= other;
|
||||
}
|
||||
TransportFlags operator^(TransportFlags other) const { return TransportFlags(*this) ^= other; }
|
||||
|
||||
TransportFlags operator^(Transport transport) const
|
||||
{
|
||||
return *this ^ TransportFlags(transport);
|
||||
}
|
||||
TransportFlags operator^(Transport transport) const { return *this ^ TransportFlags(transport); }
|
||||
|
||||
TransportFlags operator~() const
|
||||
{
|
||||
return TransportFlags(*this).flip();
|
||||
}
|
||||
TransportFlags operator~() const { return TransportFlags(*this).flip(); }
|
||||
|
||||
bool operator==(TransportFlags other) const
|
||||
{
|
||||
return detail::TransportFlagsBase::operator==(other);
|
||||
}
|
||||
bool operator==(TransportFlags other) const { return detail::TransportFlagsBase::operator==(other); }
|
||||
|
||||
bool operator!=(TransportFlags other) const
|
||||
{
|
||||
return detail::TransportFlagsBase::operator!=(other);
|
||||
}
|
||||
bool operator!=(TransportFlags other) const { return detail::TransportFlagsBase::operator!=(other); }
|
||||
|
||||
detail::TransportFlagsBase toBitset() const
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
detail::TransportFlagsBase toBitset() const { return *this; }
|
||||
|
||||
private:
|
||||
TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset)
|
||||
{
|
||||
}
|
||||
private:
|
||||
TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {}
|
||||
};
|
||||
|
||||
inline TransportFlags operator|(Transport transport1, Transport transport2)
|
||||
{
|
||||
inline TransportFlags operator|(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) | transport2;
|
||||
}
|
||||
|
||||
inline TransportFlags operator&(Transport transport1, Transport transport2)
|
||||
{
|
||||
inline TransportFlags operator&(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) & transport2;
|
||||
}
|
||||
|
||||
inline TransportFlags operator^(Transport transport1, Transport transport2)
|
||||
{
|
||||
inline TransportFlags operator^(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) ^ transport2;
|
||||
}
|
||||
|
||||
@@ -236,14 +163,13 @@ Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
class Communicator;
|
||||
class Connection;
|
||||
|
||||
class RegisteredMemory
|
||||
{
|
||||
class RegisteredMemory {
|
||||
struct Impl;
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
|
||||
// lazily.
|
||||
std::shared_ptr<Impl> pimpl;
|
||||
|
||||
public:
|
||||
public:
|
||||
RegisteredMemory() = default;
|
||||
RegisteredMemory(std::shared_ptr<Impl> pimpl);
|
||||
~RegisteredMemory();
|
||||
@@ -260,9 +186,8 @@ public:
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
class Connection
|
||||
{
|
||||
public:
|
||||
class Connection {
|
||||
public:
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) = 0;
|
||||
|
||||
@@ -276,47 +201,34 @@ public:
|
||||
|
||||
virtual Transport remoteTransport() = 0;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
|
||||
};
|
||||
|
||||
struct Setuppable
|
||||
{
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap>)
|
||||
{
|
||||
}
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>)
|
||||
{
|
||||
}
|
||||
struct Setuppable {
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
};
|
||||
|
||||
template <typename T> class NonblockingFuture
|
||||
{
|
||||
template <typename T>
|
||||
class NonblockingFuture {
|
||||
std::shared_future<T> future;
|
||||
|
||||
public:
|
||||
public:
|
||||
NonblockingFuture() = default;
|
||||
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future))
|
||||
{
|
||||
}
|
||||
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
|
||||
NonblockingFuture(const NonblockingFuture&) = default;
|
||||
|
||||
bool ready() const
|
||||
{
|
||||
return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
|
||||
}
|
||||
bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }
|
||||
|
||||
T get()
|
||||
{
|
||||
if (!ready())
|
||||
throw std::runtime_error("NonblockingFuture::get() called before ready");
|
||||
T get() {
|
||||
if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage);
|
||||
return future.get();
|
||||
}
|
||||
};
|
||||
|
||||
class Communicator
|
||||
{
|
||||
public:
|
||||
class Communicator {
|
||||
public:
|
||||
/* Initialize the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
@@ -360,26 +272,25 @@ public:
|
||||
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport);
|
||||
|
||||
/* Add a custom Setuppable object to a list of objects to be setup later, when setup() is called. */
|
||||
void addSetup(std::shared_ptr<Setuppable> setuppable);
|
||||
void onSetup(std::shared_ptr<Setuppable> setuppable);
|
||||
|
||||
/* Setup all objects that have registered for setup. This includes any connections created by connect(). */
|
||||
void setup();
|
||||
|
||||
struct Impl;
|
||||
|
||||
private:
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
namespace std {
|
||||
template <> struct hash<mscclpp::TransportFlags>
|
||||
{
|
||||
size_t operator()(const mscclpp::TransportFlags& flags) const
|
||||
{
|
||||
template <>
|
||||
struct hash<mscclpp::TransportFlags> {
|
||||
size_t operator()(const mscclpp::TransportFlags& flags) const {
|
||||
return hash<mscclpp::detail::TransportFlagsBase>()(flags.toBitset());
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
} // namespace std
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
#endif // MSCCLPP_CORE_HPP_
|
||||
67
include/mscclpp/epoch.hpp
Normal file
67
include/mscclpp/epoch.hpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#ifndef MSCCLPP_EPOCH_HPP_
|
||||
#define MSCCLPP_EPOCH_HPP_
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct alignas(16) EpochIds {
|
||||
uint64_t outbound;
|
||||
uint64_t inboundReplica;
|
||||
};
|
||||
|
||||
class BaseEpoch {
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
RegisteredMemory localEpochIdsRegMem_;
|
||||
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;
|
||||
|
||||
protected:
|
||||
EpochIds* epochIds_;
|
||||
uint64_t* expectedInboundEpochId_;
|
||||
|
||||
public:
|
||||
BaseEpoch(std::shared_ptr<Connection> connection);
|
||||
void setup(Communicator& communicator);
|
||||
BaseEpoch(const BaseEpoch&) = delete;
|
||||
void signal();
|
||||
};
|
||||
|
||||
class DeviceEpoch : BaseEpoch {
|
||||
public:
|
||||
DeviceEpoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
DeviceEpoch(const DeviceEpoch&) = delete;
|
||||
~DeviceEpoch();
|
||||
void signal();
|
||||
|
||||
struct DeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds->inboundReplica) < (*expectedInboundEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(epochIds->outbound) += 1; }
|
||||
#endif // __CUDACC__
|
||||
|
||||
EpochIds* epochIds;
|
||||
uint64_t* expectedInboundEpochId;
|
||||
};
|
||||
|
||||
DeviceHandle deviceHandle();
|
||||
};
|
||||
|
||||
class HostEpoch : BaseEpoch {
|
||||
public:
|
||||
HostEpoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
HostEpoch(const HostEpoch&) = delete;
|
||||
~HostEpoch();
|
||||
|
||||
void increamentAndSignal();
|
||||
void wait();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
@@ -4,43 +4,47 @@
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mscclpp {
|
||||
class BaseError : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
|
||||
enum class ErrorCode {
|
||||
SystemError,
|
||||
InternalError,
|
||||
InvalidUsage,
|
||||
};
|
||||
|
||||
class BaseError : public std::runtime_error {
|
||||
public:
|
||||
BaseError(std::string message, int errorCode);
|
||||
virtual ~BaseError() = default;
|
||||
int getErrorCode() const;
|
||||
|
||||
private:
|
||||
private:
|
||||
int errorCode_;
|
||||
};
|
||||
|
||||
class Error : public BaseError
|
||||
{
|
||||
public:
|
||||
Error(std::string message, int errorCode);
|
||||
class Error : public BaseError {
|
||||
public:
|
||||
Error(std::string message, ErrorCode errorCode);
|
||||
virtual ~Error() = default;
|
||||
};
|
||||
|
||||
class CudaError : public BaseError
|
||||
{
|
||||
public:
|
||||
class CudaError : public BaseError {
|
||||
public:
|
||||
CudaError(std::string message, int errorCode);
|
||||
virtual ~CudaError() = default;
|
||||
};
|
||||
|
||||
class CuError : public BaseError
|
||||
{
|
||||
public:
|
||||
class CuError : public BaseError {
|
||||
public:
|
||||
CuError(std::string message, int errorCode);
|
||||
virtual ~CuError() = default;
|
||||
};
|
||||
|
||||
class IbError : public BaseError
|
||||
{
|
||||
public:
|
||||
class IbError : public BaseError {
|
||||
public:
|
||||
IbError(std::string message, int errorCode);
|
||||
virtual ~IbError() = default;
|
||||
};
|
||||
}; // namespace mscclpp
|
||||
#endif // MSCCLPP_ERRORS_HPP
|
||||
|
||||
}; // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_ERRORS_HPP_
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef MSCCLPPFIFO_HPP_
|
||||
#define MSCCLPPFIFO_HPP_
|
||||
#ifndef MSCCLPP_FIFO_HPP_
|
||||
#define MSCCLPP_FIFO_HPP_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -12,8 +13,7 @@ namespace mscclpp {
|
||||
#define MSCCLPP_PROXY_FIFO_SIZE 128
|
||||
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
|
||||
|
||||
struct alignas(16) ProxyTrigger
|
||||
{
|
||||
struct alignas(16) ProxyTrigger {
|
||||
uint64_t fst, snd;
|
||||
};
|
||||
|
||||
@@ -30,11 +30,9 @@ struct alignas(16) ProxyTrigger
|
||||
* Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct DeviceProxyFifo
|
||||
{
|
||||
struct DeviceProxyFifo {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
|
||||
{
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica))
|
||||
;
|
||||
@@ -44,17 +42,16 @@ struct DeviceProxyFifo
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
#endif // __CUDACC__
|
||||
|
||||
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
};
|
||||
|
||||
class HostProxyFifo
|
||||
{
|
||||
public:
|
||||
class HostProxyFifo {
|
||||
public:
|
||||
HostProxyFifo();
|
||||
|
||||
~HostProxyFifo();
|
||||
@@ -67,11 +64,11 @@ public:
|
||||
|
||||
DeviceProxyFifo deviceFifo();
|
||||
|
||||
private:
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPPFIFO_H_
|
||||
#endif // MSCCLPP_FIFO_HPP_
|
||||
@@ -1,15 +1,13 @@
|
||||
#ifndef MSCCLPP_PROXY_HPP_
|
||||
#define MSCCLPP_PROXY_HPP_
|
||||
|
||||
#include "mscclppfifo.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
enum class ProxyHandlerResult
|
||||
{
|
||||
enum class ProxyHandlerResult {
|
||||
Continue,
|
||||
FlushFifoTailAndContinue,
|
||||
Stop,
|
||||
@@ -18,9 +16,8 @@ enum class ProxyHandlerResult
|
||||
class Proxy;
|
||||
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy
|
||||
{
|
||||
public:
|
||||
class Proxy {
|
||||
public:
|
||||
Proxy(ProxyHandler handler, std::function<void()> threadInit);
|
||||
Proxy(ProxyHandler handler);
|
||||
~Proxy();
|
||||
@@ -30,11 +27,11 @@ public:
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
|
||||
private:
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_PROXY_HPP_
|
||||
#endif // MSCCLPP_PROXY_HPP_
|
||||
@@ -1,5 +1,2 @@
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.h)
|
||||
file(GLOB to_remove gdr.cc)
|
||||
list(REMOVE_ITEM SOURCES ${to_remove})
|
||||
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
|
||||
target_sources(mscclpp PRIVATE ${SOURCES})
|
||||
|
||||
@@ -1,36 +1,25 @@
|
||||
#include "bootstrap.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "utils.h"
|
||||
#include <sys/resource.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <list>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <sys/resource.h>
|
||||
#include <sys/types.h>
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "socket.h"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace mscclpp;
|
||||
|
||||
namespace {
|
||||
uint64_t hashUniqueId(const mscclppBootstrapHandle& id)
|
||||
{
|
||||
const char* bytes = (const char*)&id;
|
||||
uint64_t h = 0xdeadbeef;
|
||||
for (int i = 0; i < (int)sizeof(mscclppBootstrapHandle); i++) {
|
||||
h ^= h >> 32;
|
||||
h *= 0x8db3db47fa2994ad;
|
||||
h += bytes[i];
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
mscclppResult_t setFilesLimit()
|
||||
{
|
||||
mscclppResult_t setFilesLimit() {
|
||||
rlimit filesLimit;
|
||||
SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit");
|
||||
filesLimit.rlim_cur = filesLimit.rlim_max;
|
||||
@@ -38,40 +27,32 @@ mscclppResult_t setFilesLimit()
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
||||
/* Socket Interface Selection type */
|
||||
enum bootstrapInterface_t
|
||||
{
|
||||
findSubnetIf = -1,
|
||||
dontCareIf = -2
|
||||
};
|
||||
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
|
||||
|
||||
struct UnexpectedMsg
|
||||
{
|
||||
struct UnexpectedMsg {
|
||||
int peer;
|
||||
int tag;
|
||||
std::shared_ptr<mscclppSocket> sock;
|
||||
};
|
||||
|
||||
struct ExtInfo
|
||||
{
|
||||
struct ExtInfo {
|
||||
int rank;
|
||||
int nRanks;
|
||||
mscclppSocketAddress extAddressListenRoot;
|
||||
mscclppSocketAddress extAddressListen;
|
||||
};
|
||||
|
||||
struct UniqueIdInternal
|
||||
{
|
||||
struct UniqueIdInternal {
|
||||
uint64_t magic;
|
||||
union mscclppSocketAddress addr;
|
||||
};
|
||||
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");
|
||||
|
||||
class Bootstrap::Impl
|
||||
{
|
||||
public:
|
||||
class Bootstrap::Impl {
|
||||
public:
|
||||
Impl(int rank, int nRanks);
|
||||
~Impl();
|
||||
void initialize(const UniqueId uniqueId);
|
||||
@@ -87,7 +68,7 @@ public:
|
||||
void barrier();
|
||||
void close();
|
||||
|
||||
private:
|
||||
private:
|
||||
UniqueIdInternal uniqueId_;
|
||||
int rank_;
|
||||
int nRanks_;
|
||||
@@ -118,20 +99,20 @@ private:
|
||||
// UniqueId MscclppBootstrap::Impl::uniqueId_;
|
||||
|
||||
Bootstrap::Impl::Impl(int rank, int nRanks)
|
||||
: rank_(rank), nRanks_(nRanks), netInitialized(false), peerCommAddresses_(nRanks, mscclppSocketAddress()),
|
||||
barrierArr_(nRanks, 0), abortFlag_(nullptr)
|
||||
{
|
||||
}
|
||||
: rank_(rank),
|
||||
nRanks_(nRanks),
|
||||
netInitialized(false),
|
||||
peerCommAddresses_(nRanks, mscclppSocketAddress()),
|
||||
barrierArr_(nRanks, 0),
|
||||
abortFlag_(nullptr) {}
|
||||
|
||||
UniqueId Bootstrap::Impl::getUniqueId() const
|
||||
{
|
||||
UniqueId Bootstrap::Impl::getUniqueId() const {
|
||||
UniqueId ret;
|
||||
std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
|
||||
return ret;
|
||||
}
|
||||
|
||||
UniqueId Bootstrap::Impl::createUniqueId()
|
||||
{
|
||||
UniqueId Bootstrap::Impl::createUniqueId() {
|
||||
netInit("");
|
||||
MSCCLPPTHROW(getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)));
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
|
||||
@@ -139,18 +120,11 @@ UniqueId Bootstrap::Impl::createUniqueId()
|
||||
return getUniqueId();
|
||||
}
|
||||
|
||||
int Bootstrap::Impl::getRank()
|
||||
{
|
||||
return rank_;
|
||||
}
|
||||
int Bootstrap::Impl::getRank() { return rank_; }
|
||||
|
||||
int Bootstrap::Impl::getNranks()
|
||||
{
|
||||
return nRanks_;
|
||||
}
|
||||
int Bootstrap::Impl::getNranks() { return nRanks_; }
|
||||
|
||||
void Bootstrap::Impl::initialize(const UniqueId uniqueId)
|
||||
{
|
||||
void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
|
||||
netInit("");
|
||||
|
||||
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
|
||||
@@ -158,8 +132,7 @@ void Bootstrap::Impl::initialize(const UniqueId uniqueId)
|
||||
establishConnections();
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::initialize(std::string ipPortPair)
|
||||
{
|
||||
void Bootstrap::Impl::initialize(std::string ipPortPair) {
|
||||
netInit(ipPortPair);
|
||||
|
||||
uniqueId_.magic = 0xdeadbeef;
|
||||
@@ -173,16 +146,14 @@ void Bootstrap::Impl::initialize(std::string ipPortPair)
|
||||
establishConnections();
|
||||
}
|
||||
|
||||
Bootstrap::Impl::~Impl()
|
||||
{
|
||||
Bootstrap::Impl::~Impl() {
|
||||
if (rootThread_.joinable()) {
|
||||
rootThread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
|
||||
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank)
|
||||
{
|
||||
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank) {
|
||||
mscclppSocket sock;
|
||||
ExtInfo info;
|
||||
|
||||
@@ -195,14 +166,14 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<
|
||||
|
||||
if (this->nRanks_ != info.nRanks) {
|
||||
throw mscclpp::Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " +
|
||||
std::to_string(info.nRanks),
|
||||
mscclppInternalError);
|
||||
std::to_string(info.nRanks),
|
||||
ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
|
||||
throw mscclpp::Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) +
|
||||
" has already checked in",
|
||||
mscclppInternalError);
|
||||
" has already checked in",
|
||||
ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
// Save the connection handle for that rank
|
||||
@@ -212,8 +183,7 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
|
||||
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
|
||||
{
|
||||
const std::vector<mscclppSocketAddress>& rankAddressesRoot) {
|
||||
mscclppSocket sock;
|
||||
int next = (peer + 1) % this->nRanks_;
|
||||
MSCCLPPTHROW(mscclppSocketInit(&sock, &rankAddressesRoot[peer], this->uniqueId_.magic, mscclppSocketTypeBootstrap));
|
||||
@@ -222,21 +192,19 @@ void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocket
|
||||
MSCCLPPTHROW(mscclppSocketClose(&sock));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::bootstrapCreateRoot()
|
||||
{
|
||||
void Bootstrap::Impl::bootstrapCreateRoot() {
|
||||
mscclppSocket listenSock;
|
||||
|
||||
// mscclppSocket* listenSock = new mscclppSocket(); // TODO(saemal) make this a shared ptr
|
||||
MSCCLPPTHROW(
|
||||
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
|
||||
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
|
||||
MSCCLPPTHROW(mscclppSocketListen(&listenSock));
|
||||
MSCCLPPTHROW(mscclppSocketGetAddr(&listenSock, &uniqueId_.addr));
|
||||
auto lambda = [this, listenSock]() { this->bootstrapRoot(listenSock); };
|
||||
rootThread_ = std::thread(lambda);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::bootstrapRoot(mscclppSocket listenSock)
|
||||
{
|
||||
void Bootstrap::Impl::bootstrapRoot(mscclppSocket listenSock) {
|
||||
int numCollected = 0;
|
||||
std::vector<mscclppSocketAddress> rankAddresses(this->nRanks_, mscclppSocketAddress());
|
||||
// for initial rank <-> root information exchange
|
||||
@@ -264,24 +232,22 @@ void Bootstrap::Impl::bootstrapRoot(mscclppSocket listenSock)
|
||||
TRACE(MSCCLPP_INIT, "DONE");
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netInit(std::string ipPortPair)
|
||||
{
|
||||
if (netInitialized)
|
||||
return;
|
||||
void Bootstrap::Impl::netInit(std::string ipPortPair) {
|
||||
if (netInitialized) return;
|
||||
if (!ipPortPair.empty()) {
|
||||
mscclppSocketAddress remoteAddr;
|
||||
if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) {
|
||||
throw mscclpp::Error(
|
||||
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>",
|
||||
mscclppInvalidArgument);
|
||||
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>",
|
||||
ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
||||
throw mscclpp::Error("NET/Socket : No usable listening interface found", mscclppInternalError);
|
||||
throw mscclpp::Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
|
||||
}
|
||||
} else {
|
||||
int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) {
|
||||
throw mscclpp::Error("Bootstrap : no socket interface found", mscclppInternalError);
|
||||
throw mscclpp::Error("Bootstrap : no socket interface found", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,8 +258,7 @@ void Bootstrap::Impl::netInit(std::string ipPortPair)
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::establishConnections()
|
||||
{
|
||||
void Bootstrap::Impl::establishConnections() {
|
||||
mscclppSocketAddress nextAddr;
|
||||
mscclppSocket sock, listenSockRoot;
|
||||
ExtInfo info;
|
||||
@@ -344,7 +309,7 @@ void Bootstrap::Impl::establishConnections()
|
||||
MSCCLPPTHROW(mscclppSocketClose(&listenSockRoot));
|
||||
|
||||
MSCCLPPTHROW(
|
||||
mscclppSocketInit(&this->ringSendSocket_, &nextAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
|
||||
mscclppSocketInit(&this->ringSendSocket_, &nextAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
|
||||
MSCCLPPTHROW(mscclppSocketConnect(&this->ringSendSocket_));
|
||||
// Accept the connect request from the previous rank in the AllGather ring
|
||||
MSCCLPPTHROW(mscclppSocketInit(&this->ringRecvSocket_));
|
||||
@@ -357,8 +322,7 @@ void Bootstrap::Impl::establishConnections()
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::allGather(void* allData, int size)
|
||||
{
|
||||
void Bootstrap::Impl::allGather(void* allData, int size) {
|
||||
char* data = static_cast<char*>(allData);
|
||||
int rank = this->rank_;
|
||||
int nRanks = this->nRanks_;
|
||||
@@ -382,26 +346,23 @@ void Bootstrap::Impl::allGather(void* allData, int size)
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size)
|
||||
{
|
||||
void Bootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size) {
|
||||
MSCCLPPTHROW(mscclppSocketSend(sock, &size, sizeof(int)));
|
||||
MSCCLPPTHROW(mscclppSocketSend(sock, const_cast<void*>(data), size));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
|
||||
{
|
||||
void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) {
|
||||
int recvSize;
|
||||
MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
|
||||
if (recvSize > size) {
|
||||
throw mscclpp::Error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " +
|
||||
std::to_string(size),
|
||||
mscclppInternalError);
|
||||
throw mscclpp::Error(
|
||||
"Message truncated : received " + std::to_string(recvSize) + " bytes instead of " + std::to_string(size),
|
||||
ErrorCode::InvalidUsage);
|
||||
}
|
||||
MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size)));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::send(void* data, int size, int peer, int tag)
|
||||
{
|
||||
void Bootstrap::Impl::send(void* data, int size, int peer, int tag) {
|
||||
mscclppSocket sock;
|
||||
MSCCLPPTHROW(mscclppSocketInit(&sock, &this->peerCommAddresses_[peer], this->uniqueId_.magic,
|
||||
mscclppSocketTypeBootstrap, this->abortFlag_));
|
||||
@@ -413,8 +374,7 @@ void Bootstrap::Impl::send(void* data, int size, int peer, int tag)
|
||||
MSCCLPPTHROW(mscclppSocketClose(&sock));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::recv(void* data, int size, int peer, int tag)
|
||||
{
|
||||
void Bootstrap::Impl::recv(void* data, int size, int peer, int tag) {
|
||||
// search over all unexpected messages
|
||||
auto lambda = [peer, tag](const UnexpectedMsg& msg) { return msg.peer == peer && msg.tag == tag; };
|
||||
auto it = std::find_if(unexpectedMessages_.begin(), unexpectedMessages_.end(), lambda);
|
||||
@@ -443,623 +403,37 @@ void Bootstrap::Impl::recv(void* data, int size, int peer, int tag)
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::barrier()
|
||||
{
|
||||
allGather(barrierArr_.data(), sizeof(int));
|
||||
}
|
||||
void Bootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
|
||||
|
||||
void Bootstrap::Impl::close()
|
||||
{
|
||||
void Bootstrap::Impl::close() {
|
||||
MSCCLPPTHROW(mscclppSocketClose(&this->listenSock_));
|
||||
MSCCLPPTHROW(mscclppSocketClose(&this->ringSendSocket_));
|
||||
MSCCLPPTHROW(mscclppSocketClose(&this->ringRecvSocket_));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Bootstrap::Bootstrap(int rank, int nRanks)
|
||||
{
|
||||
MSCCLPP_API_CPP Bootstrap::Bootstrap(int rank, int nRanks) {
|
||||
// pimpl_ = std::make_unique<Impl>(ipPortPair, rank, nRanks, uniqueId);
|
||||
pimpl_ = std::make_unique<Impl>(rank, nRanks);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::createUniqueId()
|
||||
{
|
||||
return pimpl_->createUniqueId();
|
||||
}
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
|
||||
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const
|
||||
{
|
||||
return pimpl_->getUniqueId();
|
||||
}
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getRank()
|
||||
{
|
||||
return pimpl_->getRank();
|
||||
}
|
||||
MSCCLPP_API_CPP int Bootstrap::getRank() { return pimpl_->getRank(); }
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getNranks()
|
||||
{
|
||||
return pimpl_->getNranks();
|
||||
}
|
||||
MSCCLPP_API_CPP int Bootstrap::getNranks() { return pimpl_->getNranks(); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag)
|
||||
{
|
||||
pimpl_->send(data, size, peer, tag);
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::recv(void* data, int size, int peer, int tag)
|
||||
{
|
||||
pimpl_->recv(data, size, peer, tag);
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::recv(void* data, int size, int peer, int tag) { pimpl_->recv(data, size, peer, tag); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::allGather(void* allData, int size)
|
||||
{
|
||||
pimpl_->allGather(allData, size);
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(UniqueId uniqueId)
|
||||
{
|
||||
pimpl_->initialize(uniqueId);
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair)
|
||||
{
|
||||
pimpl_->initialize(ipPortPair);
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->initialize(ipPortPair); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::barrier()
|
||||
{
|
||||
pimpl_->barrier();
|
||||
}
|
||||
MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); }
|
||||
|
||||
MSCCLPP_API_CPP Bootstrap::~Bootstrap()
|
||||
{
|
||||
pimpl_->close();
|
||||
}
|
||||
|
||||
// ------------------- Old bootstrap functions -------------------
|
||||
struct BootstrapRootArgs
|
||||
{
|
||||
struct mscclppSocket* listenSock;
|
||||
uint64_t magic;
|
||||
};
|
||||
|
||||
/* Init functions */
|
||||
static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1];
|
||||
static union mscclppSocketAddress bootstrapNetIfAddr;
|
||||
static int bootstrapNetInitDone = 0;
|
||||
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
mscclppResult_t bootstrapNetInit(const char* ip_port_pair)
|
||||
{
|
||||
if (bootstrapNetInitDone == 0) {
|
||||
pthread_mutex_lock(&bootstrapNetLock);
|
||||
if (bootstrapNetInitDone == 0) {
|
||||
const char* env;
|
||||
if (ip_port_pair) {
|
||||
env = ip_port_pair;
|
||||
} else {
|
||||
env = getenv("MSCCLPP_COMM_ID");
|
||||
}
|
||||
if (env) {
|
||||
union mscclppSocketAddress remoteAddr;
|
||||
if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) {
|
||||
WARN("Invalid MSCCLPP_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
if (mscclppFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE,
|
||||
1) <= 0) {
|
||||
WARN("NET/Socket : No usable listening interface found");
|
||||
return mscclppSystemError;
|
||||
}
|
||||
} else {
|
||||
int nIfs = mscclppFindInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
|
||||
if (nIfs <= 0) {
|
||||
WARN("Bootstrap : no socket interface found");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
}
|
||||
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
|
||||
sprintf(line, " %s:", bootstrapNetIfName);
|
||||
mscclppSocketToString(&bootstrapNetIfAddr, line + strlen(line));
|
||||
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
|
||||
bootstrapNetInitDone = 1;
|
||||
}
|
||||
pthread_mutex_unlock(&bootstrapNetLock);
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
// Additional sync functions
|
||||
static mscclppResult_t bootstrapNetSend(struct mscclppSocket* sock, void* data, int size)
|
||||
{
|
||||
MSCCLPPCHECK(mscclppSocketSend(sock, &size, sizeof(int)));
|
||||
MSCCLPPCHECK(mscclppSocketSend(sock, data, size));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
static mscclppResult_t bootstrapNetRecv(struct mscclppSocket* sock, void* data, int size)
|
||||
{
|
||||
int recvSize;
|
||||
MSCCLPPCHECK(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
|
||||
if (recvSize > size) {
|
||||
WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
MSCCLPPCHECK(mscclppSocketRecv(sock, data, std::min(recvSize, size)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
// struct ExtInfo
|
||||
// {
|
||||
// int rank;
|
||||
// int nranks;
|
||||
// union mscclppSocketAddress extAddressListenRoot;
|
||||
// union mscclppSocketAddress extAddressListen;
|
||||
// };
|
||||
|
||||
#include <sys/resource.h>
|
||||
|
||||
// static mscclppResult_t setFilesLimit()
|
||||
// {
|
||||
// struct rlimit filesLimit;
|
||||
// SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit");
|
||||
// filesLimit.rlim_cur = filesLimit.rlim_max;
|
||||
// SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit");
|
||||
// return mscclppSuccess;
|
||||
// }
|
||||
|
||||
static void* bootstrapRoot(void* rargs)
|
||||
{
|
||||
struct BootstrapRootArgs* args = (struct BootstrapRootArgs*)rargs;
|
||||
struct mscclppSocket* listenSock = args->listenSock;
|
||||
uint64_t magic = args->magic;
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
int nranks = 0, c = 0;
|
||||
struct ExtInfo info;
|
||||
union mscclppSocketAddress* rankAddresses = NULL;
|
||||
union mscclppSocketAddress* rankAddressesRoot = NULL; // for initial rank <-> root information exchange
|
||||
union mscclppSocketAddress* zero = NULL;
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&zero, 1), res, out);
|
||||
setFilesLimit();
|
||||
|
||||
TRACE(MSCCLPP_INIT, "BEGIN");
|
||||
/* Receive addresses from all ranks */
|
||||
do {
|
||||
struct mscclppSocket sock;
|
||||
MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), res, out);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, listenSock), res, out);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out);
|
||||
|
||||
if (c == 0) {
|
||||
nranks = info.nRanks;
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddresses, nranks), res, out);
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddressesRoot, nranks), res, out);
|
||||
}
|
||||
|
||||
if (nranks != info.nRanks) {
|
||||
WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nRanks);
|
||||
goto out;
|
||||
}
|
||||
|
||||
if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union mscclppSocketAddress)) != 0) {
|
||||
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
|
||||
goto out;
|
||||
}
|
||||
|
||||
// Save the connection handle for that rank
|
||||
memcpy(rankAddressesRoot + info.rank, &info.extAddressListenRoot, sizeof(union mscclppSocketAddress));
|
||||
memcpy(rankAddresses + info.rank, &info.extAddressListen, sizeof(union mscclppSocketAddress));
|
||||
|
||||
++c;
|
||||
TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
|
||||
} while (c < nranks);
|
||||
TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", nranks);
|
||||
|
||||
// Send the connect handle for the next rank in the AllGather ring
|
||||
for (int r = 0; r < nranks; ++r) {
|
||||
int next = (r + 1) % nranks;
|
||||
struct mscclppSocket sock;
|
||||
MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, rankAddressesRoot + r, magic, mscclppSocketTypeBootstrap), res, out);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), res, out);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, rankAddresses + next, sizeof(union mscclppSocketAddress)), res, out);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out);
|
||||
}
|
||||
TRACE(MSCCLPP_INIT, "SENT OUT ALL %d HANDLES", nranks);
|
||||
|
||||
out:
|
||||
if (listenSock != NULL) {
|
||||
mscclppSocketClose(listenSock);
|
||||
free(listenSock);
|
||||
}
|
||||
if (rankAddresses)
|
||||
free(rankAddresses);
|
||||
if (rankAddressesRoot)
|
||||
free(rankAddressesRoot);
|
||||
if (zero)
|
||||
free(zero);
|
||||
free(rargs);
|
||||
|
||||
TRACE(MSCCLPP_INIT, "DONE");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle)
|
||||
{
|
||||
struct mscclppSocket* listenSock;
|
||||
struct BootstrapRootArgs* args;
|
||||
pthread_t thread;
|
||||
|
||||
MSCCLPPCHECK(mscclppCalloc(&listenSock, 1));
|
||||
MSCCLPPCHECK(mscclppSocketInit(listenSock, &handle->addr, handle->magic, mscclppSocketTypeBootstrap, NULL, 0));
|
||||
MSCCLPPCHECK(mscclppSocketListen(listenSock));
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(listenSock, &handle->addr));
|
||||
|
||||
MSCCLPPCHECK(mscclppCalloc(&args, 1));
|
||||
args->listenSock = listenSock;
|
||||
args->magic = handle->magic;
|
||||
NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, (void*)args), 0);
|
||||
mscclppSetThreadName(thread, "MSCCLPP BootstrapR");
|
||||
NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
// #include <netinet/in.h>
|
||||
// #include <arpa/inet.h>
|
||||
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, const char* ip_port_pair)
|
||||
{
|
||||
memset(handle, 0, sizeof(mscclppBootstrapHandle));
|
||||
const char* env = NULL;
|
||||
|
||||
if (ip_port_pair) {
|
||||
env = ip_port_pair;
|
||||
} else {
|
||||
env = getenv("MSCCLPP_COMM_ID");
|
||||
}
|
||||
if (env) {
|
||||
handle->magic = 0xdeadbeef;
|
||||
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env);
|
||||
if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) {
|
||||
WARN("Invalid MSCCLPP_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
if (isRoot)
|
||||
MSCCLPPCHECK(bootstrapCreateRoot(handle));
|
||||
} else {
|
||||
MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic)));
|
||||
memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress));
|
||||
MSCCLPPCHECK(bootstrapCreateRoot(handle));
|
||||
}
|
||||
printf("addr = %s port = %d\n", inet_ntoa(handle->addr.sin.sin_addr), (int)ntohs(handle->addr.sin.sin_port));
|
||||
// printf("addr = %s\n", inet_ntoa((*(struct sockaddr_in*)&handle->addr.sa).sin_addr));
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct UnexConn
|
||||
{
|
||||
int peer;
|
||||
int tag;
|
||||
struct mscclppSocket sock;
|
||||
struct UnexConn* next;
|
||||
};
|
||||
|
||||
struct BootstrapState
|
||||
{
|
||||
struct mscclppSocket listenSock;
|
||||
struct mscclppSocket ringRecvSocket;
|
||||
struct mscclppSocket ringSendSocket;
|
||||
union mscclppSocketAddress* peerCommAddresses;
|
||||
union mscclppSocketAddress* peerProxyAddresses;
|
||||
struct UnexConn* unexpectedConnections;
|
||||
int cudaDev;
|
||||
int rank;
|
||||
int nranks;
|
||||
uint64_t magic;
|
||||
volatile uint32_t* abortFlag;
|
||||
};
|
||||
|
||||
mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm)
|
||||
{
|
||||
int rank = comm->rank;
|
||||
int nranks = comm->nRanks;
|
||||
struct BootstrapState* state;
|
||||
struct mscclppSocket* proxySocket;
|
||||
mscclppSocketAddress nextAddr;
|
||||
struct mscclppSocket sock, listenSockRoot;
|
||||
struct ExtInfo info;
|
||||
|
||||
MSCCLPPCHECK(mscclppCalloc(&state, 1));
|
||||
state->rank = rank;
|
||||
state->nranks = nranks;
|
||||
state->abortFlag = comm->abortFlag;
|
||||
comm->bootstrap = state;
|
||||
comm->magic = state->magic = handle->magic;
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks);
|
||||
|
||||
info.rank = rank;
|
||||
info.nRanks = nranks;
|
||||
|
||||
// Create socket for other ranks to contact me
|
||||
MSCCLPPCHECK(mscclppSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap,
|
||||
comm->abortFlag));
|
||||
MSCCLPPCHECK(mscclppSocketListen(&state->listenSock));
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, &info.extAddressListen));
|
||||
|
||||
// Create socket for root to contact me
|
||||
MSCCLPPCHECK(
|
||||
mscclppSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
|
||||
MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot));
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot));
|
||||
|
||||
// stagger connection times to avoid an overload of the root
|
||||
if (nranks > 128) {
|
||||
long msec = rank;
|
||||
struct timespec tv;
|
||||
tv.tv_sec = msec / 1000;
|
||||
tv.tv_nsec = 1000000 * (msec % 1000);
|
||||
TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, msec);
|
||||
(void)nanosleep(&tv, NULL);
|
||||
}
|
||||
|
||||
// send info on my listening socket to root
|
||||
MSCCLPPCHECK(mscclppSocketInit(&sock, &handle->addr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
|
||||
MSCCLPPCHECK(mscclppSocketConnect(&sock));
|
||||
MSCCLPPCHECK(bootstrapNetSend(&sock, &info, sizeof(info)));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&sock));
|
||||
|
||||
// get info on my "next" rank in the bootstrap ring from root
|
||||
MSCCLPPCHECK(mscclppSocketInit(&sock));
|
||||
MSCCLPPCHECK(mscclppSocketAccept(&sock, &listenSockRoot));
|
||||
MSCCLPPCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union mscclppSocketAddress)));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&sock));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&listenSockRoot));
|
||||
|
||||
MSCCLPPCHECK(
|
||||
mscclppSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag));
|
||||
MSCCLPPCHECK(mscclppSocketConnect(&state->ringSendSocket));
|
||||
// Accept the connect request from the previous rank in the AllGather ring
|
||||
MSCCLPPCHECK(mscclppSocketInit(&state->ringRecvSocket));
|
||||
MSCCLPPCHECK(mscclppSocketAccept(&state->ringRecvSocket, &state->listenSock));
|
||||
|
||||
// AllGather all listen handlers
|
||||
MSCCLPPCHECK(mscclppCalloc(&state->peerCommAddresses, nranks));
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, state->peerCommAddresses + rank));
|
||||
MSCCLPPCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union mscclppSocketAddress)));
|
||||
|
||||
// Create the service proxy
|
||||
MSCCLPPCHECK(mscclppCalloc(&state->peerProxyAddresses, nranks));
|
||||
|
||||
// proxy is aborted through a message; don't set abortFlag
|
||||
MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1));
|
||||
MSCCLPPCHECK(
|
||||
mscclppSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeProxy, comm->abortFlag));
|
||||
MSCCLPPCHECK(mscclppSocketListen(proxySocket));
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, state->peerProxyAddresses + rank));
|
||||
MSCCLPPCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union mscclppSocketAddress)));
|
||||
// MSCCLPPCHECK(mscclppProxyInit(comm, proxySocket, state->peerProxyAddresses));
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks);
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size)
|
||||
{
|
||||
struct BootstrapState* state = (struct BootstrapState*)commState;
|
||||
char* data = (char*)allData;
|
||||
int rank = state->rank;
|
||||
int nranks = state->nranks;
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nranks, size);
|
||||
|
||||
/* Simple ring based AllGather
|
||||
* At each step i receive data from (rank-i-1) from left
|
||||
* and send previous step's data from (rank-i) to right
|
||||
*/
|
||||
for (int i = 0; i < nranks - 1; i++) {
|
||||
size_t rslice = (rank - i - 1 + nranks) % nranks;
|
||||
size_t sslice = (rank - i + nranks) % nranks;
|
||||
|
||||
// Send slice to the right
|
||||
MSCCLPPCHECK(bootstrapNetSend(&state->ringSendSocket, data + sslice * size, size));
|
||||
// Recv slice from the left
|
||||
MSCCLPPCHECK(bootstrapNetRecv(&state->ringRecvSocket, data + rslice * size, size));
|
||||
}
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size)
|
||||
{
|
||||
mscclppResult_t ret = mscclppSuccess;
|
||||
struct BootstrapState* state = (struct BootstrapState*)commState;
|
||||
struct mscclppSocket sock;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, state->peerCommAddresses + peer, state->magic, mscclppSocketTypeBootstrap,
|
||||
state->abortFlag),
|
||||
ret, fail);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), ret, fail);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, fail);
|
||||
|
||||
exit:
|
||||
MSCCLPPCHECK(mscclppSocketClose(&sock));
|
||||
return ret;
|
||||
fail:
|
||||
goto exit;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag)
|
||||
{
|
||||
if (nranks == 1)
|
||||
return mscclppSuccess;
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);
|
||||
|
||||
/* Simple intra process barrier
|
||||
*
|
||||
* Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
|
||||
* "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
|
||||
*/
|
||||
int data[1];
|
||||
for (int mask = 1; mask < nranks; mask <<= 1) {
|
||||
int src = (rank - mask + nranks) % nranks;
|
||||
int dst = (rank + mask) % nranks;
|
||||
MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data)));
|
||||
MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data)));
|
||||
}
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size)
|
||||
{
|
||||
if (nranks == 1)
|
||||
return mscclppSuccess;
|
||||
char* data = (char*)allData;
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - ENTER", rank, nranks, size);
|
||||
|
||||
for (int i = 1; i < nranks; i++) {
|
||||
int src = (rank - i + nranks) % nranks;
|
||||
int dst = (rank + i) % nranks;
|
||||
MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], /*tag=*/i, data + rank * size, size));
|
||||
MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], /*tag=*/i, data + src * size, size));
|
||||
}
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t unexpectedEnqueue(struct BootstrapState* state, int peer, int tag, struct mscclppSocket* sock)
|
||||
{
|
||||
// New unex
|
||||
struct UnexConn* unex;
|
||||
MSCCLPPCHECK(mscclppCalloc(&unex, 1));
|
||||
unex->peer = peer;
|
||||
unex->tag = tag;
|
||||
memcpy(&unex->sock, sock, sizeof(struct mscclppSocket));
|
||||
|
||||
// Enqueue
|
||||
struct UnexConn* list = state->unexpectedConnections;
|
||||
if (list == NULL) {
|
||||
state->unexpectedConnections = unex;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
while (list->next)
|
||||
list = list->next;
|
||||
list->next = unex;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t unexpectedDequeue(struct BootstrapState* state, int peer, int tag, struct mscclppSocket* sock,
|
||||
int* found)
|
||||
{
|
||||
struct UnexConn* elem = state->unexpectedConnections;
|
||||
struct UnexConn* prev = NULL;
|
||||
*found = 0;
|
||||
while (elem) {
|
||||
if (elem->peer == peer && elem->tag == tag) {
|
||||
if (prev == NULL) {
|
||||
state->unexpectedConnections = elem->next;
|
||||
} else {
|
||||
prev->next = elem->next;
|
||||
}
|
||||
memcpy(sock, &elem->sock, sizeof(struct mscclppSocket));
|
||||
free(elem);
|
||||
*found = 1;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
prev = elem;
|
||||
elem = elem->next;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static void unexpectedFree(struct BootstrapState* state)
|
||||
{
|
||||
struct UnexConn* elem = state->unexpectedConnections;
|
||||
struct UnexConn* prev = NULL;
|
||||
|
||||
while (elem) {
|
||||
prev = elem;
|
||||
elem = elem->next;
|
||||
free(prev);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// We can't know who we'll receive from, so we need to receive everything at once
|
||||
mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size)
|
||||
{
|
||||
mscclppResult_t ret = mscclppSuccess;
|
||||
struct BootstrapState* state = (struct BootstrapState*)commState;
|
||||
struct mscclppSocket sock;
|
||||
int newPeer, newTag;
|
||||
|
||||
// Search unexpected connections first
|
||||
int found;
|
||||
MSCCLPPCHECK(unexpectedDequeue(state, peer, tag, &sock, &found));
|
||||
if (found) {
|
||||
MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail);
|
||||
goto exit;
|
||||
}
|
||||
|
||||
// Then look for new connections
|
||||
while (1) {
|
||||
MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), ret, fail);
|
||||
MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, &state->listenSock), ret, fail);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail);
|
||||
MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail);
|
||||
if (newPeer == peer && newTag == tag) {
|
||||
MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail);
|
||||
goto exit;
|
||||
}
|
||||
// Unexpected connection. Save for later.
|
||||
MSCCLPPCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail);
|
||||
}
|
||||
exit:
|
||||
MSCCLPPCHECK(mscclppSocketClose(&sock));
|
||||
return ret;
|
||||
fail:
|
||||
goto exit;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapClose(void* commState)
|
||||
{
|
||||
struct BootstrapState* state = (struct BootstrapState*)commState;
|
||||
if (state->unexpectedConnections != NULL) {
|
||||
unexpectedFree(state);
|
||||
if (*state->abortFlag == 0) {
|
||||
WARN("Unexpected connections are not empty");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->listenSock));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket));
|
||||
|
||||
free(state->peerCommAddresses);
|
||||
free(state);
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t bootstrapAbort(void* commState)
|
||||
{
|
||||
struct BootstrapState* state = (struct BootstrapState*)commState;
|
||||
if (commState == NULL)
|
||||
return mscclppSuccess;
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->listenSock));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket));
|
||||
MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket));
|
||||
free(state->peerCommAddresses);
|
||||
free(state->peerProxyAddresses);
|
||||
free(state);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); }
|
||||
|
||||
@@ -5,25 +5,23 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "socket.h"
|
||||
#include "config.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "config.h"
|
||||
#include "utils.h"
|
||||
|
||||
static mscclppResult_t socketProgressOpt(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset,
|
||||
int block, int* closed)
|
||||
{
|
||||
int block, int* closed) {
|
||||
int bytes = 0;
|
||||
*closed = 0;
|
||||
char* data = (char*)ptr;
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
do {
|
||||
if (op == MSCCLPP_SOCKET_RECV)
|
||||
bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == MSCCLPP_SOCKET_RECV) bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == MSCCLPP_SOCKET_SEND)
|
||||
bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
|
||||
if (op == MSCCLPP_SOCKET_RECV && bytes == 0) {
|
||||
@@ -48,8 +46,7 @@ static mscclppResult_t socketProgressOpt(int op, struct mscclppSocket* sock, voi
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset)
|
||||
{
|
||||
static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) {
|
||||
int closed;
|
||||
MSCCLPPCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed));
|
||||
if (closed) {
|
||||
@@ -60,10 +57,8 @@ static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void*
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset)
|
||||
{
|
||||
while (*offset < size)
|
||||
MSCCLPPCHECK(socketProgress(op, sock, ptr, size, offset));
|
||||
static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) {
|
||||
while (*offset < size) MSCCLPPCHECK(socketProgress(op, sock, ptr, size, offset));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -71,10 +66,8 @@ static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr,
|
||||
*
|
||||
* Output: "IPv4/IPv6 address<port>"
|
||||
*/
|
||||
const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/)
|
||||
{
|
||||
if (buf == NULL || addr == NULL)
|
||||
return NULL;
|
||||
const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) {
|
||||
if (buf == NULL || addr == NULL) return NULL;
|
||||
struct sockaddr* saddr = &addr->sa;
|
||||
if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
|
||||
buf[0] = '\0';
|
||||
@@ -90,68 +83,58 @@ const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, c
|
||||
return buf;
|
||||
}
|
||||
|
||||
static uint16_t socketToPort(union mscclppSocketAddress* addr)
|
||||
{
|
||||
static uint16_t socketToPort(union mscclppSocketAddress* addr) {
|
||||
struct sockaddr* saddr = &addr->sa;
|
||||
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
|
||||
}
|
||||
|
||||
/* Allow the user to force the IPv4/IPv6 interface selection */
|
||||
static int envSocketFamily(void)
|
||||
{
|
||||
int family = -1; // Family selection is not forced, will use first one found
|
||||
static int envSocketFamily(void) {
|
||||
int family = -1; // Family selection is not forced, will use first one found
|
||||
char* env = getenv("MSCCLPP_SOCKET_FAMILY");
|
||||
if (env == NULL)
|
||||
return family;
|
||||
if (env == NULL) return family;
|
||||
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_SOCKET_FAMILY set by environment to %s", env);
|
||||
|
||||
if (strcmp(env, "AF_INET") == 0)
|
||||
family = AF_INET; // IPv4
|
||||
family = AF_INET; // IPv4
|
||||
else if (strcmp(env, "AF_INET6") == 0)
|
||||
family = AF_INET6; // IPv6
|
||||
family = AF_INET6; // IPv6
|
||||
return family;
|
||||
}
|
||||
|
||||
static int findInterfaces(const char* prefixList, char* names, union mscclppSocketAddress* addrs, int sock_family,
|
||||
int maxIfNameSize, int maxIfs)
|
||||
{
|
||||
int maxIfNameSize, int maxIfs) {
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
struct netIf userIfs[MAX_IFS];
|
||||
bool searchNot = prefixList && prefixList[0] == '^';
|
||||
if (searchNot)
|
||||
prefixList++;
|
||||
if (searchNot) prefixList++;
|
||||
bool searchExact = prefixList && prefixList[0] == '=';
|
||||
if (searchExact)
|
||||
prefixList++;
|
||||
if (searchExact) prefixList++;
|
||||
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
|
||||
|
||||
int found = 0;
|
||||
struct ifaddrs *interfaces, *interface;
|
||||
getifaddrs(&interfaces);
|
||||
for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
|
||||
if (interface->ifa_addr == NULL)
|
||||
continue;
|
||||
if (interface->ifa_addr == NULL) continue;
|
||||
|
||||
/* We only support IPv4 & IPv6 */
|
||||
int family = interface->ifa_addr->sa_family;
|
||||
if (family != AF_INET && family != AF_INET6)
|
||||
continue;
|
||||
if (family != AF_INET && family != AF_INET6) continue;
|
||||
|
||||
TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Found interface %s:%s", interface->ifa_name,
|
||||
mscclppSocketToString((union mscclppSocketAddress*)interface->ifa_addr, line));
|
||||
|
||||
/* Allow the caller to force the socket family type */
|
||||
if (sock_family != -1 && family != sock_family)
|
||||
continue;
|
||||
if (sock_family != -1 && family != sock_family) continue;
|
||||
|
||||
/* We also need to skip IPv6 loopback interfaces */
|
||||
if (family == AF_INET6) {
|
||||
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
|
||||
if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr))
|
||||
continue;
|
||||
if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue;
|
||||
}
|
||||
|
||||
// check against user specified interfaces
|
||||
@@ -183,8 +166,7 @@ static int findInterfaces(const char* prefixList, char* names, union mscclppSock
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* remote)
|
||||
{
|
||||
static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* remote) {
|
||||
/* Check family first */
|
||||
int family = local_if.ifa_addr->sa_family;
|
||||
if (family != remote->sa.sa_family) {
|
||||
@@ -207,8 +189,8 @@ static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* rem
|
||||
struct in6_addr& mask_in6 = mask->sin6_addr;
|
||||
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
|
||||
bool same = true;
|
||||
int len = 16; // IPv6 address is 16 unsigned char
|
||||
for (int c = 0; c < len; c++) { // Network byte order is big-endian
|
||||
int len = 16; // IPv6 address is 16 unsigned char
|
||||
for (int c = 0; c < len; c++) { // Network byte order is big-endian
|
||||
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
|
||||
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
|
||||
if (c1 ^ c2) {
|
||||
@@ -228,8 +210,7 @@ static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* rem
|
||||
}
|
||||
|
||||
int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* localAddrs,
|
||||
union mscclppSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs)
|
||||
{
|
||||
union mscclppSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
@@ -238,13 +219,11 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l
|
||||
struct ifaddrs *interfaces, *interface;
|
||||
getifaddrs(&interfaces);
|
||||
for (interface = interfaces; interface && !found; interface = interface->ifa_next) {
|
||||
if (interface->ifa_addr == NULL)
|
||||
continue;
|
||||
if (interface->ifa_addr == NULL) continue;
|
||||
|
||||
/* We only support IPv4 & IPv6 */
|
||||
int family = interface->ifa_addr->sa_family;
|
||||
if (family != AF_INET && family != AF_INET6)
|
||||
continue;
|
||||
if (family != AF_INET && family != AF_INET6) continue;
|
||||
|
||||
// check against user specified interfaces
|
||||
if (!matchSubnet(*interface, remoteAddr)) {
|
||||
@@ -262,8 +241,7 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l
|
||||
interface->ifa_name, mscclppSocketToString(localAddrs + found, line),
|
||||
mscclppSocketToString(remoteAddr, line_a));
|
||||
found++;
|
||||
if (found == maxIfs)
|
||||
break;
|
||||
if (found == maxIfs) break;
|
||||
}
|
||||
|
||||
if (found == 0) {
|
||||
@@ -273,8 +251,7 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l
|
||||
return found;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, const char* ip_port_pair)
|
||||
{
|
||||
mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, const char* ip_port_pair) {
|
||||
if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
|
||||
WARN("Net : string is null");
|
||||
return mscclppInvalidArgument;
|
||||
@@ -305,36 +282,34 @@ mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, c
|
||||
if (p->ai_family == AF_INET) {
|
||||
struct sockaddr_in& sin = ua->sin;
|
||||
memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
|
||||
sin.sin_family = AF_INET; // IPv4
|
||||
sin.sin_family = AF_INET; // IPv4
|
||||
// inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
|
||||
sin.sin_port = htons(ni.port); // port
|
||||
sin.sin_port = htons(ni.port); // port
|
||||
} else if (p->ai_family == AF_INET6) {
|
||||
struct sockaddr_in6& sin6 = ua->sin6;
|
||||
memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
sin6.sin6_port = htons(ni.port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = 0; // should be global scope, set to 0
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
sin6.sin6_port = htons(ni.port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = 0; // should be global scope, set to 0
|
||||
} else {
|
||||
WARN("Net : unsupported IP family");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
|
||||
freeaddrinfo(p); // all done with this structure
|
||||
freeaddrinfo(p); // all done with this structure
|
||||
|
||||
} else {
|
||||
int i, j = -1, len = strlen(ip_port_pair);
|
||||
for (i = 1; i < len; i++) {
|
||||
if (ip_port_pair[i] == '%')
|
||||
j = i;
|
||||
if (ip_port_pair[i] == ']')
|
||||
break;
|
||||
if (ip_port_pair[i] == '%') j = i;
|
||||
if (ip_port_pair[i] == ']') break;
|
||||
}
|
||||
if (i == len) {
|
||||
WARN("Net : No valid [IPv6]:port pair found");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
|
||||
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
|
||||
|
||||
char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
|
||||
memset(ip_str, '\0', sizeof(ip_str));
|
||||
@@ -343,21 +318,19 @@ mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, c
|
||||
strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1);
|
||||
strncpy(port_str, ip_port_pair + i + 2, len - i - 1);
|
||||
int port = atoi(port_str);
|
||||
if (!global_scope)
|
||||
strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name
|
||||
if (!global_scope) strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name
|
||||
|
||||
struct sockaddr_in6& sin6 = ua->sin6;
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
|
||||
sin6.sin6_port = htons(port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
|
||||
sin6.sin6_port = htons(port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs)
|
||||
{
|
||||
int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
|
||||
static int shownIfName = 0;
|
||||
int nIfs = 0;
|
||||
// Allow user to force the INET socket family selection
|
||||
@@ -367,8 +340,7 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, in
|
||||
if (env && strlen(env) > 1) {
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_SOCKET_IFNAME set by environment to %s", env);
|
||||
// Specified by user : find or fail
|
||||
if (shownIfName++ == 0)
|
||||
INFO(MSCCLPP_NET, "MSCCLPP_SOCKET_IFNAME set to %s", env);
|
||||
if (shownIfName++ == 0) INFO(MSCCLPP_NET, "MSCCLPP_SOCKET_IFNAME set to %s", env);
|
||||
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
} else {
|
||||
// Try to automatically pick the right one
|
||||
@@ -386,19 +358,15 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, in
|
||||
}
|
||||
}
|
||||
// Then look for anything else (but not docker or lo)
|
||||
if (nIfs == 0)
|
||||
nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
// Finally look for docker, then lo.
|
||||
if (nIfs == 0)
|
||||
nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
if (nIfs == 0)
|
||||
nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
}
|
||||
return nIfs;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock)
|
||||
{
|
||||
mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock) {
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketListen: pass NULL socket");
|
||||
return mscclppInvalidArgument;
|
||||
@@ -438,20 +406,17 @@ mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketGetAddr(struct mscclppSocket* sock, union mscclppSocketAddress* addr)
|
||||
{
|
||||
mscclppResult_t mscclppSocketGetAddr(struct mscclppSocket* sock, union mscclppSocketAddress* addr) {
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketGetAddr: pass NULL socket");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
if (sock->state != mscclppSocketStateReady)
|
||||
return mscclppInternalError;
|
||||
if (sock->state != mscclppSocketStateReady) return mscclppInternalError;
|
||||
memcpy(addr, &sock->addr, sizeof(union mscclppSocketAddress));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketTryAccept(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketTryAccept(struct mscclppSocket* sock) {
|
||||
static bool timeInitialized = false;
|
||||
static mscclppTime_t initTime;
|
||||
if (!timeInitialized) {
|
||||
@@ -482,14 +447,12 @@ static mscclppResult_t socketTryAccept(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock) {
|
||||
uint64_t magic;
|
||||
enum mscclppSocketType type;
|
||||
int received = 0;
|
||||
MSCCLPPCHECK(mscclppSocketProgress(MSCCLPP_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
|
||||
if (received == 0)
|
||||
return mscclppSuccess;
|
||||
if (received == 0) return mscclppSuccess;
|
||||
MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
|
||||
if (magic != sock->magic) {
|
||||
WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic);
|
||||
@@ -514,8 +477,7 @@ static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketStartConnect(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketStartConnect(struct mscclppSocket* sock) {
|
||||
static bool timeInitialized = false;
|
||||
static mscclppTime_t initTime;
|
||||
if (!timeInitialized) {
|
||||
@@ -543,8 +505,7 @@ static mscclppResult_t socketStartConnect(struct mscclppSocket* sock)
|
||||
return mscclppRemoteError;
|
||||
}
|
||||
usleep(SLEEP_INT);
|
||||
if (++sock->connectRetries % 1000 == 0)
|
||||
INFO(MSCCLPP_ALL, "Call to connect returned %s, retrying", strerror(errno));
|
||||
if (++sock->connectRetries % 1000 == 0) INFO(MSCCLPP_ALL, "Call to connect returned %s, retrying", strerror(errno));
|
||||
return mscclppSuccess;
|
||||
} else {
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
@@ -555,8 +516,7 @@ static mscclppResult_t socketStartConnect(struct mscclppSocket* sock)
|
||||
}
|
||||
}
|
||||
|
||||
static mscclppResult_t socketPollConnect(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketPollConnect(struct mscclppSocket* sock) {
|
||||
static bool timeInitialized = false;
|
||||
static mscclppTime_t initTime;
|
||||
if (!timeInitialized) {
|
||||
@@ -608,8 +568,7 @@ static mscclppResult_t socketPollConnect(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock)
|
||||
{
|
||||
mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock) {
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketPollConnect: pass NULL socket");
|
||||
return mscclppInvalidArgument;
|
||||
@@ -618,12 +577,10 @@ mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock) {
|
||||
int sent = 0;
|
||||
MSCCLPPCHECK(socketProgress(MSCCLPP_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
|
||||
if (sent == 0)
|
||||
return mscclppSuccess;
|
||||
if (sent == 0) return mscclppSuccess;
|
||||
MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
|
||||
sent = 0;
|
||||
MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
|
||||
@@ -631,8 +588,7 @@ static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static mscclppResult_t socketProgressState(struct mscclppSocket* sock)
|
||||
{
|
||||
static mscclppResult_t socketProgressState(struct mscclppSocket* sock) {
|
||||
if (sock->state == mscclppSocketStateAccepting) {
|
||||
MSCCLPPCHECK(socketTryAccept(sock));
|
||||
}
|
||||
@@ -668,8 +624,7 @@ static mscclppResult_t socketProgressState(struct mscclppSocket* sock)
|
||||
// return mscclppSuccess;
|
||||
// }
|
||||
|
||||
mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock)
|
||||
{
|
||||
mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock) {
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
@@ -686,8 +641,7 @@ mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock)
|
||||
|
||||
if (sock->state != mscclppSocketStateInitialized) {
|
||||
WARN("mscclppSocketConnect: wrong socket state %d", sock->state);
|
||||
if (sock->state == mscclppSocketStateError)
|
||||
return mscclppRemoteError;
|
||||
if (sock->state == mscclppSocketStateError) return mscclppRemoteError;
|
||||
return mscclppInternalError;
|
||||
}
|
||||
TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Connecting to socket %s", mscclppSocketToString(&sock->addr, line));
|
||||
@@ -701,25 +655,23 @@ mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock)
|
||||
(sock->state == mscclppSocketStateConnecting || sock->state == mscclppSocketStateConnectPolling ||
|
||||
sock->state == mscclppSocketStateConnected));
|
||||
|
||||
if (sock->abortFlag && *sock->abortFlag != 0)
|
||||
return mscclppInternalError;
|
||||
if (sock->abortFlag && *sock->abortFlag != 0) return mscclppInternalError;
|
||||
|
||||
switch (sock->state) {
|
||||
case mscclppSocketStateConnecting:
|
||||
case mscclppSocketStateConnectPolling:
|
||||
case mscclppSocketStateConnected:
|
||||
case mscclppSocketStateReady:
|
||||
return mscclppSuccess;
|
||||
case mscclppSocketStateError:
|
||||
return mscclppSystemError;
|
||||
default:
|
||||
WARN("mscclppSocketConnect: wrong socket state %d", sock->state);
|
||||
return mscclppInternalError;
|
||||
case mscclppSocketStateConnecting:
|
||||
case mscclppSocketStateConnectPolling:
|
||||
case mscclppSocketStateConnected:
|
||||
case mscclppSocketStateReady:
|
||||
return mscclppSuccess;
|
||||
case mscclppSocketStateError:
|
||||
return mscclppSystemError;
|
||||
default:
|
||||
WARN("mscclppSocketConnect: wrong socket state %d", sock->state);
|
||||
return mscclppInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSocket* listenSock)
|
||||
{
|
||||
mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSocket* listenSock) {
|
||||
mscclppResult_t ret = mscclppSuccess;
|
||||
|
||||
if (listenSock == NULL || sock == NULL) {
|
||||
@@ -747,22 +699,21 @@ mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSo
|
||||
} while (sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
|
||||
(sock->state == mscclppSocketStateAccepting || sock->state == mscclppSocketStateAccepted));
|
||||
|
||||
if (sock->abortFlag && *sock->abortFlag != 0)
|
||||
return mscclppInternalError;
|
||||
if (sock->abortFlag && *sock->abortFlag != 0) return mscclppInternalError;
|
||||
|
||||
switch (sock->state) {
|
||||
case mscclppSocketStateAccepting:
|
||||
case mscclppSocketStateAccepted:
|
||||
case mscclppSocketStateReady:
|
||||
ret = mscclppSuccess;
|
||||
break;
|
||||
case mscclppSocketStateError:
|
||||
ret = mscclppSystemError;
|
||||
break;
|
||||
default:
|
||||
WARN("mscclppSocketAccept: wrong socket state %d", sock->state);
|
||||
ret = mscclppInternalError;
|
||||
break;
|
||||
case mscclppSocketStateAccepting:
|
||||
case mscclppSocketStateAccepted:
|
||||
case mscclppSocketStateReady:
|
||||
ret = mscclppSuccess;
|
||||
break;
|
||||
case mscclppSocketStateError:
|
||||
ret = mscclppSystemError;
|
||||
break;
|
||||
default:
|
||||
WARN("mscclppSocketAccept: wrong socket state %d", sock->state);
|
||||
ret = mscclppInternalError;
|
||||
break;
|
||||
}
|
||||
|
||||
exit:
|
||||
@@ -770,12 +721,10 @@ exit:
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketInit(struct mscclppSocket* sock, const mscclppSocketAddress* addr, uint64_t magic,
|
||||
enum mscclppSocketType type, volatile uint32_t* abortFlag, int asyncFlag)
|
||||
{
|
||||
enum mscclppSocketType type, volatile uint32_t* abortFlag, int asyncFlag) {
|
||||
mscclppResult_t ret = mscclppSuccess;
|
||||
|
||||
if (sock == NULL)
|
||||
goto exit;
|
||||
if (sock == NULL) goto exit;
|
||||
sock->connectRetries = 0;
|
||||
sock->acceptRetries = 0;
|
||||
sock->abortFlag = abortFlag;
|
||||
@@ -824,8 +773,7 @@ fail:
|
||||
goto exit;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset)
|
||||
{
|
||||
mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) {
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketProgress: pass NULL socket");
|
||||
return mscclppInvalidArgument;
|
||||
@@ -843,8 +791,7 @@ mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void*
|
||||
// return mscclppSuccess;
|
||||
// }
|
||||
|
||||
mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int size)
|
||||
{
|
||||
mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketSend: pass NULL socket");
|
||||
@@ -858,8 +805,7 @@ mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int siz
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int size)
|
||||
{
|
||||
mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
if (sock == NULL) {
|
||||
WARN("mscclppSocketRecv: pass NULL socket");
|
||||
@@ -888,11 +834,9 @@ mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int siz
|
||||
// return mscclppSuccess;
|
||||
// }
|
||||
|
||||
mscclppResult_t mscclppSocketClose(struct mscclppSocket* sock)
|
||||
{
|
||||
mscclppResult_t mscclppSocketClose(struct mscclppSocket* sock) {
|
||||
if (sock != NULL) {
|
||||
if (sock->fd >= 0)
|
||||
close(sock->fd);
|
||||
if (sock->fd >= 0) close(sock->fd);
|
||||
sock->state = mscclppSocketStateClosed;
|
||||
sock->fd = -1;
|
||||
}
|
||||
|
||||
39
src/c_style_remnants.cc
Normal file
39
src/c_style_remnants.cc
Normal file
@@ -0,0 +1,39 @@
|
||||
#include "api.h"
|
||||
#include "config.h"
|
||||
#include "debug.h"
|
||||
#include "mscclpp.h"
|
||||
|
||||
MSCCLPP_API void mscclppDefaultLogHandler(const char* msg) { mscclppDebugDefaultLogHandler(msg); }
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler) {
|
||||
return mscclppDebugSetLogHandler(handler);
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) {
|
||||
mscclppConfig* config = mscclppConfig::getInstance();
|
||||
config->setBootstrapConnectionTimeoutConfig(timeout);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API const char* mscclppGetErrorString(mscclppResult_t code) {
|
||||
switch (code) {
|
||||
case mscclppSuccess:
|
||||
return "no error";
|
||||
case mscclppUnhandledCudaError:
|
||||
return "unhandled cuda error";
|
||||
case mscclppSystemError:
|
||||
return "unhandled system error";
|
||||
case mscclppInternalError:
|
||||
return "internal error";
|
||||
case mscclppInvalidArgument:
|
||||
return "invalid argument";
|
||||
case mscclppInvalidUsage:
|
||||
return "invalid usage";
|
||||
case mscclppRemoteError:
|
||||
return "remote process exited or there was a network error";
|
||||
case mscclppInProgress:
|
||||
return "MSCCL++ operation in progress";
|
||||
default:
|
||||
return "unknown result code";
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "channel.hpp"
|
||||
#include <mscclpp/channel.hpp>
|
||||
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "debug.h"
|
||||
@@ -8,21 +9,19 @@ namespace mscclpp {
|
||||
namespace channel {
|
||||
|
||||
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator)
|
||||
: communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); })
|
||||
{
|
||||
: communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
int cudaDevice;
|
||||
CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void DeviceChannelService::bindThread()
|
||||
{
|
||||
MSCCLPP_API_CPP void DeviceChannelService::bindThread() {
|
||||
if (deviceNumaNode >= 0) {
|
||||
MSCCLPPTHROW(numaBind(deviceNumaNode));
|
||||
INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
#include "communicator.hpp"
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <sstream>
|
||||
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "comm.h"
|
||||
#include "communicator.hpp"
|
||||
#include "connection.hpp"
|
||||
#include "debug.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap)
|
||||
{
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {
|
||||
rankToHash_.resize(bootstrap->getNranks());
|
||||
auto hostHash = getHostHash();
|
||||
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
|
||||
@@ -21,13 +20,9 @@ Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(
|
||||
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
|
||||
}
|
||||
|
||||
Communicator::Impl::~Impl()
|
||||
{
|
||||
ibContexts_.clear();
|
||||
}
|
||||
Communicator::Impl::~Impl() { ibContexts_.clear(); }
|
||||
|
||||
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport)
|
||||
{
|
||||
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts_.find(ibTransport);
|
||||
if (it == ibContexts_.end()) {
|
||||
@@ -42,29 +37,20 @@ IbCtx* Communicator::Impl::getIbContext(Transport ibTransport)
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
: pimpl(std::make_unique<Impl>(bootstrap))
|
||||
{
|
||||
}
|
||||
: pimpl(std::make_unique<Impl>(bootstrap)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper()
|
||||
{
|
||||
return pimpl->bootstrap_;
|
||||
}
|
||||
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper() { return pimpl->bootstrap_; }
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports)
|
||||
{
|
||||
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
|
||||
return RegisteredMemory(
|
||||
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
|
||||
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
|
||||
}
|
||||
|
||||
struct MemorySender : public Setuppable
|
||||
{
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
struct MemorySender : public Setuppable {
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
|
||||
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
|
||||
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
|
||||
}
|
||||
|
||||
@@ -73,19 +59,14 @@ struct MemorySender : public Setuppable
|
||||
int tag_;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag)
|
||||
{
|
||||
addSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
|
||||
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) {
|
||||
onSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
|
||||
}
|
||||
|
||||
struct MemoryReceiver : public Setuppable
|
||||
{
|
||||
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
struct MemoryReceiver : public Setuppable {
|
||||
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
|
||||
std::vector<char> data;
|
||||
bootstrap->recv(data, remoteRank_, tag_);
|
||||
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
|
||||
@@ -96,15 +77,13 @@ struct MemoryReceiver : public Setuppable
|
||||
int tag_;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag)
|
||||
{
|
||||
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag) {
|
||||
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
|
||||
addSetup(memoryReceiver);
|
||||
onSetup(memoryReceiver);
|
||||
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport)
|
||||
{
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) {
|
||||
std::shared_ptr<ConnectionBase> conn;
|
||||
if (transport == Transport::CudaIpc) {
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
@@ -114,7 +93,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"
|
||||
<< " != " << pimpl->bootstrap_->getRank() << "(" << std::hex
|
||||
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
|
||||
throw mscclpp::Error(ss.str(), mscclppInternalError);
|
||||
throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
|
||||
conn = cudaIpcConn;
|
||||
@@ -128,20 +107,18 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
|
||||
} else {
|
||||
throw mscclpp::Error("Unsupported transport", mscclppInvalidArgument);
|
||||
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
|
||||
}
|
||||
pimpl->connections_.push_back(conn);
|
||||
addSetup(conn);
|
||||
onSetup(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::addSetup(std::shared_ptr<Setuppable> setuppable)
|
||||
{
|
||||
MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr<Setuppable> setuppable) {
|
||||
pimpl->toSetup_.push_back(setuppable);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::setup()
|
||||
{
|
||||
MSCCLPP_API_CPP void Communicator::setup() {
|
||||
for (auto& setuppable : pimpl->toSetup_) {
|
||||
setuppable->beginSetup(pimpl->bootstrap_);
|
||||
}
|
||||
@@ -151,4 +128,4 @@ MSCCLPP_API_CPP void Communicator::setup()
|
||||
pimpl->toSetup_.clear();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -2,17 +2,8 @@
|
||||
|
||||
mscclppConfig mscclppConfig::_instance;
|
||||
|
||||
mscclppConfig* mscclppConfig::getInstance()
|
||||
{
|
||||
return &_instance;
|
||||
}
|
||||
mscclppConfig* mscclppConfig::getInstance() { return &_instance; }
|
||||
|
||||
time_t mscclppConfig::getBootstrapConnectionTimeoutConfig()
|
||||
{
|
||||
return bootstrapConnectionTimeout;
|
||||
}
|
||||
time_t mscclppConfig::getBootstrapConnectionTimeoutConfig() { return bootstrapConnectionTimeout; }
|
||||
|
||||
void mscclppConfig::setBootstrapConnectionTimeoutConfig(time_t timeout)
|
||||
{
|
||||
bootstrapConnectionTimeout = timeout;
|
||||
}
|
||||
void mscclppConfig::setBootstrapConnectionTimeoutConfig(time_t timeout) { bootstrapConnectionTimeout = timeout; }
|
||||
|
||||
@@ -1,68 +1,47 @@
|
||||
#include "connection.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "checks.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "npkit/npkit.h"
|
||||
#include "registered_memory.hpp"
|
||||
#include "utils.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
void validateTransport(RegisteredMemory mem, Transport transport)
|
||||
{
|
||||
void validateTransport(RegisteredMemory mem, Transport transport) {
|
||||
if (!mem.transports().has(transport)) {
|
||||
throw Error("RegisteredMemory does not support transport", mscclppInvalidArgument);
|
||||
throw Error("RegisteredMemory does not support this transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection
|
||||
|
||||
std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(RegisteredMemory& mem)
|
||||
{
|
||||
return mem.pimpl;
|
||||
}
|
||||
std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(RegisteredMemory& mem) { return mem.pimpl; }
|
||||
|
||||
// ConnectionBase
|
||||
|
||||
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
int ConnectionBase::remoteRank()
|
||||
{
|
||||
return remoteRank_;
|
||||
}
|
||||
int ConnectionBase::remoteRank() { return remoteRank_; }
|
||||
|
||||
int ConnectionBase::tag()
|
||||
{
|
||||
return tag_;
|
||||
}
|
||||
int ConnectionBase::tag() { return tag_; }
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag)
|
||||
{
|
||||
CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag) {
|
||||
CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaIpcConnection::~CudaIpcConnection()
|
||||
{
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); }
|
||||
|
||||
Transport CudaIpcConnection::transport()
|
||||
{
|
||||
return Transport::CudaIpc;
|
||||
}
|
||||
Transport CudaIpcConnection::transport() { return Transport::CudaIpc; }
|
||||
|
||||
Transport CudaIpcConnection::remoteTransport()
|
||||
{
|
||||
return Transport::CudaIpc;
|
||||
}
|
||||
Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; }
|
||||
|
||||
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
uint64_t size) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
@@ -75,8 +54,7 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void CudaIpcConnection::flush()
|
||||
{
|
||||
void CudaIpcConnection::flush() {
|
||||
CUDATHROW(cudaStreamSynchronize(stream));
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
}
|
||||
@@ -84,34 +62,29 @@ void CudaIpcConnection::flush()
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
|
||||
: ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0)
|
||||
{
|
||||
: ConnectionBase(remoteRank, tag),
|
||||
transport_(transport),
|
||||
remoteTransport_(Transport::Unknown),
|
||||
numSignaledSends(0) {
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
}
|
||||
|
||||
Transport IBConnection::transport()
|
||||
{
|
||||
return transport_;
|
||||
}
|
||||
Transport IBConnection::transport() { return transport_; }
|
||||
|
||||
Transport IBConnection::remoteTransport()
|
||||
{
|
||||
return remoteTransport_;
|
||||
}
|
||||
Transport IBConnection::remoteTransport() { return remoteTransport_; }
|
||||
|
||||
void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
uint64_t size) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
|
||||
if (dstTransportInfo.ibLocal) {
|
||||
throw Error("dst is local, which is not supported", mscclppInvalidArgument);
|
||||
throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage);
|
||||
}
|
||||
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
|
||||
if (!srcTransportInfo.ibLocal) {
|
||||
throw Error("src is remote, which is not supported", mscclppInvalidArgument);
|
||||
throw Error("src is remote, which is not supported", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
@@ -126,8 +99,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void IBConnection::flush()
|
||||
{
|
||||
void IBConnection::flush() {
|
||||
Timer timer;
|
||||
while (numSignaledSends) {
|
||||
int wcNum = qp->pollCq();
|
||||
@@ -137,9 +109,9 @@ void IBConnection::flush()
|
||||
|
||||
auto elapsed = timer.elapsed();
|
||||
if (elapsed > MSCCLPP_POLLING_WAIT) {
|
||||
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " +
|
||||
std::to_string(numSignaledSends) + " signals",
|
||||
mscclppInternalError);
|
||||
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
|
||||
std::to_string(numSignaledSends) + " signals",
|
||||
ErrorCode::InternalError);
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
@@ -154,8 +126,7 @@ void IBConnection::flush()
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
{
|
||||
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
std::vector<char> ibQpTransport;
|
||||
std::copy_n(reinterpret_cast<char*>(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport));
|
||||
std::copy_n(reinterpret_cast<char*>(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport));
|
||||
@@ -163,8 +134,7 @@ void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
}
|
||||
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
{
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
std::vector<char> ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport));
|
||||
bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
|
||||
@@ -179,4 +149,4 @@ void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
qp->rts();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
65
src/debug.cc
65
src/debug.cc
@@ -5,6 +5,7 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdlib.h>
|
||||
@@ -15,8 +16,8 @@ int mscclppDebugLevel = -1;
|
||||
static int pid = -1;
|
||||
static char hostname[1024];
|
||||
thread_local int mscclppDebugNoWarn = 0;
|
||||
char mscclppLastError[1024] = ""; // Global string for the last error in human readable form
|
||||
uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT
|
||||
char mscclppLastError[1024] = ""; // Global string for the last error in human readable form
|
||||
uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT
|
||||
FILE* mscclppDebugFile = stdout;
|
||||
mscclppLogHandler_t mscclppDebugLogHandler = NULL;
|
||||
pthread_mutex_t mscclppDebugLock = PTHREAD_MUTEX_INITIALIZER;
|
||||
@@ -24,13 +25,9 @@ std::chrono::steady_clock::time_point mscclppEpoch;
|
||||
|
||||
static __thread int tid = -1;
|
||||
|
||||
void mscclppDebugDefaultLogHandler(const char* msg)
|
||||
{
|
||||
fwrite(msg, 1, strlen(msg), mscclppDebugFile);
|
||||
}
|
||||
void mscclppDebugDefaultLogHandler(const char* msg) { fwrite(msg, 1, strlen(msg), mscclppDebugFile); }
|
||||
|
||||
void mscclppDebugInit()
|
||||
{
|
||||
void mscclppDebugInit() {
|
||||
pthread_mutex_lock(&mscclppDebugLock);
|
||||
if (mscclppDebugLevel != -1) {
|
||||
pthread_mutex_unlock(&mscclppDebugLock);
|
||||
@@ -121,33 +118,32 @@ void mscclppDebugInit()
|
||||
continue;
|
||||
}
|
||||
switch (mscclppDebugFileEnv[c++]) {
|
||||
case '%': // Double %
|
||||
*dfn++ = '%';
|
||||
break;
|
||||
case 'h': // %h = hostname
|
||||
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
|
||||
break;
|
||||
case 'p': // %p = pid
|
||||
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
|
||||
break;
|
||||
default: // Echo everything we don't understand
|
||||
*dfn++ = '%';
|
||||
*dfn++ = mscclppDebugFileEnv[c - 1];
|
||||
break;
|
||||
case '%': // Double %
|
||||
*dfn++ = '%';
|
||||
break;
|
||||
case 'h': // %h = hostname
|
||||
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
|
||||
break;
|
||||
case 'p': // %p = pid
|
||||
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
|
||||
break;
|
||||
default: // Echo everything we don't understand
|
||||
*dfn++ = '%';
|
||||
*dfn++ = mscclppDebugFileEnv[c - 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
*dfn = '\0';
|
||||
if (debugFn[0] != '\0') {
|
||||
FILE* file = fopen(debugFn, "w");
|
||||
if (file != nullptr) {
|
||||
setbuf(file, nullptr); // disable buffering
|
||||
setbuf(file, nullptr); // disable buffering
|
||||
mscclppDebugFile = file;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mscclppDebugLogHandler == NULL)
|
||||
mscclppDebugLogHandler = mscclppDefaultLogHandler;
|
||||
if (mscclppDebugLogHandler == NULL) mscclppDebugLogHandler = mscclppDefaultLogHandler;
|
||||
|
||||
mscclppEpoch = std::chrono::steady_clock::now();
|
||||
__atomic_store_n(&mscclppDebugLevel, tempNcclDebugLevel, __ATOMIC_RELEASE);
|
||||
@@ -159,10 +155,8 @@ void mscclppDebugInit()
|
||||
* they can share the debugging mechanisms and output files
|
||||
*/
|
||||
void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char* filefunc, int line, const char* fmt,
|
||||
...)
|
||||
{
|
||||
if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1)
|
||||
mscclppDebugInit();
|
||||
...) {
|
||||
if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) mscclppDebugInit();
|
||||
if (mscclppDebugNoWarn != 0 && level == MSCCLPP_LOG_WARN) {
|
||||
level = MSCCLPP_LOG_INFO;
|
||||
flags = mscclppDebugNoWarn;
|
||||
@@ -176,8 +170,7 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
|
||||
va_end(vargs);
|
||||
pthread_mutex_unlock(&mscclppDebugLock);
|
||||
}
|
||||
if (mscclppDebugLevel < level || ((flags & mscclppDebugMask) == 0))
|
||||
return;
|
||||
if (mscclppDebugLevel < level || ((flags & mscclppDebugMask) == 0)) return;
|
||||
|
||||
if (tid == -1) {
|
||||
tid = syscall(SYS_gettid);
|
||||
@@ -218,20 +211,16 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char
|
||||
}
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler)
|
||||
{
|
||||
if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1)
|
||||
mscclppDebugInit();
|
||||
if (handler == NULL)
|
||||
return mscclppInvalidArgument;
|
||||
mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler) {
|
||||
if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) mscclppDebugInit();
|
||||
if (handler == NULL) return mscclppInvalidArgument;
|
||||
pthread_mutex_lock(&mscclppDebugLock);
|
||||
mscclppDebugLogHandler = handler;
|
||||
pthread_mutex_unlock(&mscclppDebugLock);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
void mscclppSetThreadName(pthread_t thread, const char* fmt, ...)
|
||||
{
|
||||
void mscclppSetThreadName(pthread_t thread, const char* fmt, ...) {
|
||||
// pthread_setname_np is nonstandard GNU extension
|
||||
// needs the following feature test macro
|
||||
#ifdef _GNU_SOURCE
|
||||
|
||||
75
src/epoch.cc
75
src/epoch.cc
@@ -1,32 +1,69 @@
|
||||
#include "epoch.hpp"
|
||||
#include <mscclpp/epoch.hpp>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: connection_(connection)
|
||||
{
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));
|
||||
BaseEpoch::BaseEpoch(std::shared_ptr<Connection> connection) : connection_(connection) {}
|
||||
|
||||
localEpochIdsRegMem_ =
|
||||
communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport());
|
||||
communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag());
|
||||
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
|
||||
void BaseEpoch::setup(Communicator& communicator) {
|
||||
localEpochIdsRegMem_ = communicator.registerMemory(epochIds_, sizeof(epochIds_), connection_->transport());
|
||||
communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection_->remoteRank(), connection_->tag());
|
||||
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection_->remoteRank(), connection_->tag());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Epoch::~Epoch()
|
||||
{
|
||||
mscclppCudaFree(device_.epochIds_);
|
||||
mscclppCudaFree(device_.expectedInboundEpochId_);
|
||||
void BaseEpoch::signal() {
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica), localEpochIdsRegMem_,
|
||||
offsetof(EpochIds, outbound), sizeof(epochIds_));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Epoch::signal()
|
||||
{
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_,
|
||||
offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
|
||||
MSCCLPP_API_CPP DeviceEpoch::DeviceEpoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: BaseEpoch(connection) {
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&epochIds_, 1));
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&expectedInboundEpochId_, 1));
|
||||
setup(communicator);
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
MSCCLPP_API_CPP DeviceEpoch::~DeviceEpoch() {
|
||||
mscclppCudaFree(epochIds_);
|
||||
mscclppCudaFree(expectedInboundEpochId_);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void DeviceEpoch::signal() { BaseEpoch::signal(); }
|
||||
|
||||
MSCCLPP_API_CPP DeviceEpoch::DeviceHandle DeviceEpoch::deviceHandle() {
|
||||
DeviceEpoch::DeviceHandle device;
|
||||
device.epochIds = epochIds_;
|
||||
device.expectedInboundEpochId = expectedInboundEpochId_;
|
||||
return device;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostEpoch::HostEpoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: BaseEpoch(connection) {
|
||||
if (connection->transport() == Transport::CudaIpc) {
|
||||
throw Error("HostEpoch cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
epochIds_ = new EpochIds();
|
||||
expectedInboundEpochId_ = new uint64_t();
|
||||
setup(communicator);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostEpoch::~HostEpoch() {
|
||||
delete epochIds_;
|
||||
delete expectedInboundEpochId_;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostEpoch::increamentAndSignal() {
|
||||
*(volatile uint64_t*)&(epochIds_->outbound) += 1;
|
||||
signal();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostEpoch::wait() {
|
||||
(*expectedInboundEpochId_) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica) < (*expectedInboundEpochId_))
|
||||
;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -1,30 +1,19 @@
|
||||
#include "errors.hpp"
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
BaseError::BaseError(std::string message, int errorCode) : std::runtime_error(message), errorCode_(errorCode)
|
||||
{
|
||||
}
|
||||
BaseError::BaseError(std::string message, int errorCode) : std::runtime_error(message), errorCode_(errorCode) {}
|
||||
|
||||
int BaseError::getErrorCode() const
|
||||
{
|
||||
return errorCode_;
|
||||
}
|
||||
int BaseError::getErrorCode() const { return errorCode_; }
|
||||
|
||||
Error::Error(std::string message, int errorCode) : BaseError(message, errorCode)
|
||||
{
|
||||
}
|
||||
MSCCLPP_API_CPP Error::Error(std::string message, ErrorCode errorCode) : BaseError(message, -1) {}
|
||||
|
||||
CudaError::CudaError(std::string message, int errorCode) : BaseError(message, errorCode)
|
||||
{
|
||||
}
|
||||
MSCCLPP_API_CPP CudaError::CudaError(std::string message, int errorCode) : BaseError(message, errorCode) {}
|
||||
|
||||
CuError::CuError(std::string message, int errorCode) : BaseError(message, errorCode)
|
||||
{
|
||||
}
|
||||
MSCCLPP_API_CPP CuError::CuError(std::string message, int errorCode) : BaseError(message, errorCode) {}
|
||||
|
||||
IbError::IbError(std::string message, int errorCode) : BaseError(message, errorCode)
|
||||
{
|
||||
}
|
||||
MSCCLPP_API_CPP IbError::IbError(std::string message, int errorCode) : BaseError(message, errorCode) {}
|
||||
|
||||
}; // namespace mscclpp
|
||||
}; // namespace mscclpp
|
||||
|
||||
35
src/fifo.cc
35
src/fifo.cc
@@ -1,15 +1,16 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <emmintrin.h>
|
||||
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
#include <emmintrin.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostProxyFifo::Impl
|
||||
{
|
||||
struct HostProxyFifo::Impl {
|
||||
DeviceProxyFifo deviceFifo;
|
||||
|
||||
// allocated on the host. Only accessed by the host. This is a copy of the
|
||||
@@ -25,8 +26,7 @@ struct HostProxyFifo::Impl
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo()
|
||||
{
|
||||
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() {
|
||||
pimpl = std::make_unique<Impl>();
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1));
|
||||
MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE));
|
||||
@@ -35,28 +35,24 @@ MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo()
|
||||
pimpl->hostTail = 0;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo()
|
||||
{
|
||||
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() {
|
||||
mscclppCudaFree(pimpl->deviceFifo.head);
|
||||
mscclppCudaHostFree(pimpl->deviceFifo.triggers);
|
||||
mscclppCudaFree(pimpl->deviceFifo.tailReplica);
|
||||
cudaStreamDestroy(pimpl->stream);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger)
|
||||
{
|
||||
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) {
|
||||
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
_mm_store_si128((__m128i*)trigger, xmm0);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::pop()
|
||||
{
|
||||
MSCCLPP_API_CPP void HostProxyFifo::pop() {
|
||||
*(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
(pimpl->hostTail)++;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync)
|
||||
{
|
||||
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) {
|
||||
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
@@ -67,9 +63,6 @@ MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync)
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo()
|
||||
{
|
||||
return pimpl->deviceFifo;
|
||||
}
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() { return pimpl->deviceFifo; }
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
75
src/gdr.cc
75
src/gdr.cc
@@ -1,75 +0,0 @@
|
||||
#include "gdr.h"
|
||||
|
||||
// Used to make the GDR library calls thread safe
|
||||
pthread_mutex_t gdrLock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
gdr_t wrap_gdr_open(void)
|
||||
{
|
||||
return gdr_open();
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_close(gdr_t g)
|
||||
{
|
||||
int ret = gdr_close(g);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_close() failed: %d", ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_pin_buffer(gdr_t g, unsigned long addr, size_t size, uint64_t p2p_token, uint32_t va_space,
|
||||
gdr_mh_t* handle)
|
||||
{
|
||||
int ret;
|
||||
GDRLOCKCALL(gdr_pin_buffer(g, addr, size, p2p_token, va_space, handle), ret);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_pin_buffer(addr %lx, size %zi) failed: %d", addr, size, ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_unpin_buffer(gdr_t g, gdr_mh_t handle)
|
||||
{
|
||||
int ret;
|
||||
GDRLOCKCALL(gdr_unpin_buffer(g, handle), ret);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_unpin_buffer(handle %lx) failed: %d", handle.h, ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_get_info(gdr_t g, gdr_mh_t handle, gdr_info_t* info)
|
||||
{
|
||||
int ret;
|
||||
GDRLOCKCALL(gdr_get_info(g, handle, info), ret);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_get_info(handle %lx) failed: %d", handle.h, ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_map(gdr_t g, gdr_mh_t handle, void** va, size_t size)
|
||||
{
|
||||
int ret;
|
||||
GDRLOCKCALL(gdr_map(g, handle, va, size), ret);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_map(handle %lx, size %zi) failed: %d", handle.h, size, ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t wrap_gdr_unmap(gdr_t g, gdr_mh_t handle, void* va, size_t size)
|
||||
{
|
||||
int ret;
|
||||
GDRLOCKCALL(gdr_unmap(g, handle, va, size), ret);
|
||||
if (ret != 0) {
|
||||
WARN("gdr_unmap(handle %lx, va %p, size %zi) failed: %d", handle.h, va, size, ret);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
221
src/ib.cc
221
src/ib.cc
@@ -1,24 +1,26 @@
|
||||
#include "ib.hpp"
|
||||
|
||||
#include <infiniband/verbs.h>
|
||||
#include <malloc.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <malloc.h>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <sstream>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <string>
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
{
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
if (size == 0) {
|
||||
throw std::invalid_argument("invalid size: " + std::to_string(size));
|
||||
}
|
||||
@@ -29,9 +31,9 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
struct ibv_mr* _mr =
|
||||
ibv_reg_mr(_pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
struct ibv_mr* _mr = ibv_reg_mr(
|
||||
_pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
if (_mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
@@ -41,31 +43,20 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
this->size = pages * pageSize;
|
||||
}
|
||||
|
||||
IbMr::~IbMr()
|
||||
{
|
||||
ibv_dereg_mr(reinterpret_cast<struct ibv_mr*>(this->mr));
|
||||
}
|
||||
IbMr::~IbMr() { ibv_dereg_mr(reinterpret_cast<struct ibv_mr*>(this->mr)); }
|
||||
|
||||
IbMrInfo IbMr::getInfo() const
|
||||
{
|
||||
IbMrInfo IbMr::getInfo() const {
|
||||
IbMrInfo info;
|
||||
info.addr = reinterpret_cast<uint64_t>(this->buff);
|
||||
info.rkey = reinterpret_cast<struct ibv_mr*>(this->mr)->rkey;
|
||||
return info;
|
||||
}
|
||||
|
||||
const void* IbMr::getBuff() const
|
||||
{
|
||||
return this->buff;
|
||||
}
|
||||
const void* IbMr::getBuff() const { return this->buff; }
|
||||
|
||||
uint32_t IbMr::getLkey() const
|
||||
{
|
||||
return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey;
|
||||
}
|
||||
uint32_t IbMr::getLkey() const { return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey; }
|
||||
|
||||
IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
{
|
||||
IbQp::IbQp(void* ctx, void* pd, int port) {
|
||||
struct ibv_context* _ctx = reinterpret_cast<struct ibv_context*>(ctx);
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
|
||||
@@ -137,8 +128,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_wc**>(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM));
|
||||
}
|
||||
|
||||
IbQp::~IbQp()
|
||||
{
|
||||
IbQp::~IbQp() {
|
||||
ibv_destroy_qp(reinterpret_cast<struct ibv_qp*>(this->qp));
|
||||
ibv_destroy_cq(reinterpret_cast<struct ibv_cq*>(this->cq));
|
||||
std::free(this->wrs);
|
||||
@@ -146,8 +136,7 @@ IbQp::~IbQp()
|
||||
std::free(this->wcs);
|
||||
}
|
||||
|
||||
void IbQp::rtr(const IbQpInfo& info)
|
||||
{
|
||||
void IbQp::rtr(const IbQpInfo& info) {
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTR;
|
||||
@@ -167,13 +156,13 @@ void IbQp::rtr(const IbQpInfo& info)
|
||||
} else {
|
||||
qp_attr.ah_attr.is_global = 0;
|
||||
}
|
||||
qp_attr.ah_attr.dlid = info->lid;
|
||||
qp_attr.ah_attr.dlid = info.lid;
|
||||
qp_attr.ah_attr.sl = 0;
|
||||
qp_attr.ah_attr.src_path_bits = 0;
|
||||
qp_attr.ah_attr.port_num = info.port;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
@@ -181,8 +170,7 @@ void IbQp::rtr(const IbQpInfo& info)
|
||||
}
|
||||
}
|
||||
|
||||
void IbQp::rts()
|
||||
{
|
||||
void IbQp::rts() {
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTS;
|
||||
@@ -191,9 +179,9 @@ void IbQp::rts()
|
||||
qp_attr.rnr_retry = 7;
|
||||
qp_attr.sq_psn = 0;
|
||||
qp_attr.max_rd_atomic = 1;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
|
||||
IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
int ret = ibv_modify_qp(
|
||||
reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
@@ -202,8 +190,7 @@ void IbQp::rts()
|
||||
}
|
||||
|
||||
int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled)
|
||||
{
|
||||
uint64_t dstOffset, bool signaled) {
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
}
|
||||
@@ -232,8 +219,7 @@ int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_
|
||||
}
|
||||
|
||||
int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
{
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData) {
|
||||
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
@@ -241,8 +227,7 @@ int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
|
||||
return wrn;
|
||||
}
|
||||
|
||||
void IbQp::postSend()
|
||||
{
|
||||
void IbQp::postSend() {
|
||||
if (this->wrn == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -257,8 +242,7 @@ void IbQp::postSend()
|
||||
this->wrn = 0;
|
||||
}
|
||||
|
||||
void IbQp::postRecv(uint64_t wrId)
|
||||
{
|
||||
void IbQp::postRecv(uint64_t wrId) {
|
||||
struct ibv_recv_wr wr, *bad_wr;
|
||||
wr.wr_id = wrId;
|
||||
wr.sg_list = nullptr;
|
||||
@@ -272,24 +256,16 @@ void IbQp::postRecv(uint64_t wrId)
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::pollCq()
|
||||
{
|
||||
int IbQp::pollCq() {
|
||||
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM,
|
||||
reinterpret_cast<struct ibv_wc*>(this->wcs));
|
||||
}
|
||||
|
||||
IbQpInfo& IbQp::getInfo()
|
||||
{
|
||||
return this->info;
|
||||
}
|
||||
IbQpInfo& IbQp::getInfo() { return this->info; }
|
||||
|
||||
const void* IbQp::getWc(int idx) const
|
||||
{
|
||||
return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx];
|
||||
}
|
||||
const void* IbQp::getWc(int idx) const { return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx]; }
|
||||
|
||||
IbCtx::IbCtx(const std::string& devName) : devName(devName)
|
||||
{
|
||||
IbCtx::IbCtx(const std::string& devName) : devName(devName) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
@@ -312,8 +288,7 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName)
|
||||
}
|
||||
}
|
||||
|
||||
IbCtx::~IbCtx()
|
||||
{
|
||||
IbCtx::~IbCtx() {
|
||||
this->mrs.clear();
|
||||
this->qps.clear();
|
||||
if (this->pd != nullptr) {
|
||||
@@ -324,8 +299,7 @@ IbCtx::~IbCtx()
|
||||
}
|
||||
}
|
||||
|
||||
bool IbCtx::isPortUsable(int port) const
|
||||
{
|
||||
bool IbCtx::isPortUsable(int port) const {
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
@@ -336,8 +310,7 @@ bool IbCtx::isPortUsable(int port) const
|
||||
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
|
||||
}
|
||||
|
||||
int IbCtx::getAnyActivePort() const
|
||||
{
|
||||
int IbCtx::getAnyActivePort() const {
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
@@ -352,70 +325,63 @@ int IbCtx::getAnyActivePort() const
|
||||
return -1;
|
||||
}
|
||||
|
||||
IbQp* IbCtx::createQp(int port /*=-1*/)
|
||||
{
|
||||
IbQp* IbCtx::createQp(int port /*=-1*/) {
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw mscclpp::Error("No active port found", mscclppInternalError);
|
||||
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), mscclppInternalError);
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
|
||||
return qps.back().get();
|
||||
}
|
||||
|
||||
const IbMr* IbCtx::registerMr(void* buff, std::size_t size)
|
||||
{
|
||||
const IbMr* IbCtx::registerMr(void* buff, std::size_t size) {
|
||||
mrs.emplace_back(new IbMr(this->pd, buff, size));
|
||||
return mrs.back().get();
|
||||
}
|
||||
|
||||
const std::string& IbCtx::getDevName() const
|
||||
{
|
||||
return this->devName;
|
||||
}
|
||||
const std::string& IbCtx::getDevName() const { return this->devName; }
|
||||
|
||||
MSCCLPP_API_CPP int getIBDeviceCount()
|
||||
{
|
||||
MSCCLPP_API_CPP int getIBDeviceCount() {
|
||||
int num;
|
||||
ibv_get_device_list(&num);
|
||||
return num;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
|
||||
{
|
||||
MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
int ibTransportIndex;
|
||||
switch (ibTransport) { // TODO: get rid of this ugly switch
|
||||
case Transport::IB0:
|
||||
ibTransportIndex = 0;
|
||||
break;
|
||||
case Transport::IB1:
|
||||
ibTransportIndex = 1;
|
||||
break;
|
||||
case Transport::IB2:
|
||||
ibTransportIndex = 2;
|
||||
break;
|
||||
case Transport::IB3:
|
||||
ibTransportIndex = 3;
|
||||
break;
|
||||
case Transport::IB4:
|
||||
ibTransportIndex = 4;
|
||||
break;
|
||||
case Transport::IB5:
|
||||
ibTransportIndex = 5;
|
||||
break;
|
||||
case Transport::IB6:
|
||||
ibTransportIndex = 6;
|
||||
break;
|
||||
case Transport::IB7:
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Not an IB transport");
|
||||
switch (ibTransport) { // TODO: get rid of this ugly switch
|
||||
case Transport::IB0:
|
||||
ibTransportIndex = 0;
|
||||
break;
|
||||
case Transport::IB1:
|
||||
ibTransportIndex = 1;
|
||||
break;
|
||||
case Transport::IB2:
|
||||
ibTransportIndex = 2;
|
||||
break;
|
||||
case Transport::IB3:
|
||||
ibTransportIndex = 3;
|
||||
break;
|
||||
case Transport::IB4:
|
||||
ibTransportIndex = 4;
|
||||
break;
|
||||
case Transport::IB5:
|
||||
ibTransportIndex = 5;
|
||||
break;
|
||||
case Transport::IB6:
|
||||
ibTransportIndex = 6;
|
||||
break;
|
||||
case Transport::IB7:
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Not an IB transport");
|
||||
}
|
||||
if (ibTransportIndex >= num) {
|
||||
throw std::out_of_range("IB transport out of range");
|
||||
@@ -423,35 +389,34 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
|
||||
return devices[ibTransportIndex]->name;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName)
|
||||
{
|
||||
MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName) {
|
||||
int num;
|
||||
struct ibv_device** devices = ibv_get_device_list(&num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
if (ibDeviceName == devices[i]->name) {
|
||||
switch (i) { // TODO: get rid of this ugly switch
|
||||
case 0:
|
||||
return Transport::IB0;
|
||||
case 1:
|
||||
return Transport::IB1;
|
||||
case 2:
|
||||
return Transport::IB2;
|
||||
case 3:
|
||||
return Transport::IB3;
|
||||
case 4:
|
||||
return Transport::IB4;
|
||||
case 5:
|
||||
return Transport::IB5;
|
||||
case 6:
|
||||
return Transport::IB6;
|
||||
case 7:
|
||||
return Transport::IB7;
|
||||
default:
|
||||
throw std::out_of_range("IB device index out of range");
|
||||
switch (i) { // TODO: get rid of this ugly switch
|
||||
case 0:
|
||||
return Transport::IB0;
|
||||
case 1:
|
||||
return Transport::IB1;
|
||||
case 2:
|
||||
return Transport::IB2;
|
||||
case 3:
|
||||
return Transport::IB3;
|
||||
case 4:
|
||||
return Transport::IB4;
|
||||
case 5:
|
||||
return Transport::IB5;
|
||||
case 6:
|
||||
return Transport::IB6;
|
||||
case 7:
|
||||
return Transport::IB7;
|
||||
default:
|
||||
throw std::out_of_range("IB device index out of range");
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::invalid_argument("IB device not found");
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -22,19 +22,19 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename X, typename Y, typename Z = decltype(X() + Y())> __host__ __device__ constexpr Z divUp(X x, Y y)
|
||||
{
|
||||
template <typename X, typename Y, typename Z = decltype(X() + Y())>
|
||||
__host__ __device__ constexpr Z divUp(X x, Y y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z = decltype(X() + Y())> __host__ __device__ constexpr Z roundUp(X x, Y y)
|
||||
{
|
||||
template <typename X, typename Y, typename Z = decltype(X() + Y())>
|
||||
__host__ __device__ constexpr Z roundUp(X x, Y y) {
|
||||
return (x + y - 1) - (x + y - 1) % y;
|
||||
}
|
||||
|
||||
// assumes second argument is a power of 2
|
||||
template <typename X, typename Z = decltype(X() + int())> __host__ __device__ constexpr Z alignUp(X x, int a)
|
||||
{
|
||||
template <typename X, typename Z = decltype(X() + int())>
|
||||
__host__ __device__ constexpr Z alignUp(X x, int a) {
|
||||
return (x + a - 1) & Z(-a);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,16 +7,17 @@
|
||||
#ifndef MSCCLPP_ALLOC_H_
|
||||
#define MSCCLPP_ALLOC_H_
|
||||
|
||||
#include "align.h"
|
||||
#include "checks.h"
|
||||
#include "mscclpp.h"
|
||||
#include "utils.h"
|
||||
#include <stdlib.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line)
|
||||
{
|
||||
#include "align.h"
|
||||
#include "checks.h"
|
||||
#include "mscclpp.h"
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
*ptr = nullptr;
|
||||
@@ -26,21 +27,19 @@ template <typename T> mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t
|
||||
memset(*ptr, 0, nelem * sizeof(T));
|
||||
finish:
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
if (*ptr == nullptr)
|
||||
WARN("Failed to CUDA host alloc %ld bytes", nelem * sizeof(T));
|
||||
if (*ptr == nullptr) WARN("Failed to CUDA host alloc %ld bytes", nelem * sizeof(T));
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
|
||||
return result;
|
||||
}
|
||||
#define mscclppCudaHostCalloc(...) mscclppCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
inline mscclppResult_t mscclppCudaHostFree(void* ptr)
|
||||
{
|
||||
inline mscclppResult_t mscclppCudaHostFree(void* ptr) {
|
||||
CUDACHECK(cudaFreeHost(ptr));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
|
||||
void* p = malloc(nelem * sizeof(T));
|
||||
if (p == NULL) {
|
||||
WARN("Failed to malloc %ld bytes", nelem * sizeof(T));
|
||||
@@ -53,12 +52,10 @@ template <typename T> mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem,
|
||||
}
|
||||
#define mscclppCalloc(...) mscclppCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
template <typename T> mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, size_t nelem)
|
||||
{
|
||||
if (nelem < oldNelem)
|
||||
return mscclppInternalError;
|
||||
if (nelem == oldNelem)
|
||||
return mscclppSuccess;
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, size_t nelem) {
|
||||
if (nelem < oldNelem) return mscclppInternalError;
|
||||
if (nelem == oldNelem) return mscclppSuccess;
|
||||
|
||||
T* oldp = *ptr;
|
||||
T* p = (T*)malloc(nelem * sizeof(T));
|
||||
@@ -75,8 +72,8 @@ template <typename T> mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, s
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nelem, const char* filefunc, int line)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
*ptr = nullptr;
|
||||
@@ -84,15 +81,14 @@ template <typename T> mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nel
|
||||
CUDACHECKGOTO(cudaMalloc(ptr, nelem * sizeof(T)), result, finish);
|
||||
finish:
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
if (*ptr == nullptr)
|
||||
WARN("Failed to CUDA malloc %ld bytes", nelem * sizeof(T));
|
||||
if (*ptr == nullptr) WARN("Failed to CUDA malloc %ld bytes", nelem * sizeof(T));
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
|
||||
return result;
|
||||
}
|
||||
#define mscclppCudaMalloc(...) mscclppCudaMallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
*ptr = nullptr;
|
||||
@@ -106,16 +102,15 @@ template <typename T> mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nel
|
||||
CUDACHECKGOTO(cudaStreamDestroy(stream), result, finish);
|
||||
finish:
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
if (*ptr == nullptr)
|
||||
WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T));
|
||||
if (*ptr == nullptr) WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T));
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
|
||||
return result;
|
||||
}
|
||||
#define mscclppCudaCalloc(...) mscclppCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t stream, const char* filefunc, int line)
|
||||
{
|
||||
mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t stream, const char* filefunc,
|
||||
int line) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
*ptr = nullptr;
|
||||
@@ -124,15 +119,14 @@ mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t
|
||||
CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem * sizeof(T), stream), result, finish);
|
||||
finish:
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
if (*ptr == nullptr)
|
||||
WARN("Failed to CUDA calloc async %ld bytes", nelem * sizeof(T));
|
||||
if (*ptr == nullptr) WARN("Failed to CUDA calloc async %ld bytes", nelem * sizeof(T));
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
|
||||
return result;
|
||||
}
|
||||
#define mscclppCudaCallocAsync(...) mscclppCudaCallocAsyncDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaMemcpy(T* dst, T* src, size_t nelem)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaMemcpy(T* dst, T* src, size_t nelem) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
@@ -147,8 +141,8 @@ finish:
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaMemcpyAsync(T* dst, T* src, size_t nelem, cudaStream_t stream)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaMemcpyAsync(T* dst, T* src, size_t nelem, cudaStream_t stream) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
@@ -158,8 +152,8 @@ finish:
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T> mscclppResult_t mscclppCudaFree(T* ptr)
|
||||
{
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppCudaFree(T* ptr) {
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
@@ -172,14 +166,12 @@ finish:
|
||||
// Allocate memory to be potentially ibv_reg_mr'd. This needs to be
|
||||
// allocated on separate pages as those pages will be marked DONTFORK
|
||||
// and if they are shared, that could cause a crash in a child process
|
||||
inline mscclppResult_t mscclppIbMallocDebug(void** ptr, size_t size, const char* filefunc, int line)
|
||||
{
|
||||
inline mscclppResult_t mscclppIbMallocDebug(void** ptr, size_t size, const char* filefunc, int line) {
|
||||
size_t page_size = sysconf(_SC_PAGESIZE);
|
||||
void* p;
|
||||
int size_aligned = ROUNDUP(size, page_size);
|
||||
int ret = posix_memalign(&p, page_size, size_aligned);
|
||||
if (ret != 0)
|
||||
return mscclppSystemError;
|
||||
if (ret != 0) return mscclppSystemError;
|
||||
memset(p, 0, size);
|
||||
*ptr = p;
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr);
|
||||
|
||||
@@ -4,4 +4,4 @@
|
||||
#define MSCCLPP_API extern "C" __attribute__((visibility("default")))
|
||||
#define MSCCLPP_API_CPP __attribute__((visibility("default")))
|
||||
|
||||
#endif // MSCCLPP_API_H_
|
||||
#endif // MSCCLPP_API_H_
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mscclpp.h"
|
||||
#include "socket.h"
|
||||
|
||||
#include "comm.h"
|
||||
|
||||
// ------------------- Old bootstrap headers: to be removed -------------------
|
||||
|
||||
struct mscclppBootstrapHandle
|
||||
{
|
||||
uint64_t magic;
|
||||
union mscclppSocketAddress addr;
|
||||
};
|
||||
mscclppResult_t bootstrapNetInit(const char* ip_port_pair = NULL);
|
||||
mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle);
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true,
|
||||
const char* ip_port_pair = NULL);
|
||||
mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm);
|
||||
mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||
mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
|
||||
mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
|
||||
mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag);
|
||||
mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size);
|
||||
mscclppResult_t bootstrapClose(void* commState);
|
||||
mscclppResult_t bootstrapAbort(void* commState);
|
||||
@@ -7,187 +7,182 @@
|
||||
#ifndef MSCCLPP_CHECKS_H_
|
||||
#define MSCCLPP_CHECKS_H_
|
||||
|
||||
#include "debug.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
// Check CUDA RT calls
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
return mscclppUnhandledCudaError; \
|
||||
} \
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
return mscclppUnhandledCudaError; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define CUDACHECKNORET(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
return; \
|
||||
} \
|
||||
#define CUDACHECKNORET(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
return; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define CUDACHECKGOTO(cmd, res, label) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
res = mscclppUnhandledCudaError; \
|
||||
goto label; \
|
||||
} \
|
||||
#define CUDACHECKGOTO(cmd, res, label) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
res = mscclppUnhandledCudaError; \
|
||||
goto label; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
// Report failure but clear error and continue
|
||||
#define CUDACHECKIGNORE(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
INFO(MSCCLPP_ALL, "%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \
|
||||
(void)cudaGetLastError(); \
|
||||
} \
|
||||
#define CUDACHECKIGNORE(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
INFO(MSCCLPP_ALL, "%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \
|
||||
(void)cudaGetLastError(); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#include <errno.h>
|
||||
// Check system calls
|
||||
#define SYSCHECK(call, name) \
|
||||
do { \
|
||||
int retval; \
|
||||
SYSCHECKVAL(call, name, retval); \
|
||||
#define SYSCHECK(call, name) \
|
||||
do { \
|
||||
int retval; \
|
||||
SYSCHECKVAL(call, name, retval); \
|
||||
} while (false)
|
||||
|
||||
#define SYSCHECKVAL(call, name, retval) \
|
||||
do { \
|
||||
SYSCHECKSYNC(call, name, retval); \
|
||||
if (retval == -1) { \
|
||||
WARN("Call to " name " failed : %s", strerror(errno)); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
#define SYSCHECKVAL(call, name, retval) \
|
||||
do { \
|
||||
SYSCHECKSYNC(call, name, retval); \
|
||||
if (retval == -1) { \
|
||||
WARN("Call to " name " failed : %s", strerror(errno)); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SYSCHECKSYNC(call, name, retval) \
|
||||
do { \
|
||||
retval = call; \
|
||||
if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
|
||||
INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
#define SYSCHECKSYNC(call, name, retval) \
|
||||
do { \
|
||||
retval = call; \
|
||||
if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
|
||||
INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
} while (true)
|
||||
|
||||
#define SYSCHECKGOTO(statement, res, label) \
|
||||
do { \
|
||||
if ((statement) == -1) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
#define SYSCHECKGOTO(statement, res, label) \
|
||||
do { \
|
||||
if ((statement) == -1) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define NEQCHECK(statement, value) \
|
||||
do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
#define NEQCHECK(statement, value) \
|
||||
do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define NEQCHECKGOTO(statement, value, res, label) \
|
||||
do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
#define NEQCHECKGOTO(statement, value, res, label) \
|
||||
do { \
|
||||
if ((statement) != value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define EQCHECK(statement, value) \
|
||||
do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
#define EQCHECK(statement, value) \
|
||||
do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define EQCHECKGOTO(statement, value, res, label) \
|
||||
do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
#define EQCHECKGOTO(statement, value, res, label) \
|
||||
do { \
|
||||
if ((statement) == value) { \
|
||||
/* Print the back trace*/ \
|
||||
res = mscclppSystemError; \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
// Propagate errors up
|
||||
#define MSCCLPPCHECK(call) \
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (mscclppDebugNoWarn == 0) \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return res; \
|
||||
} \
|
||||
#define MSCCLPPCHECK(call) \
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return res; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define MSCCLPPCHECKGOTO(call, res, label) \
|
||||
do { \
|
||||
res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (mscclppDebugNoWarn == 0) \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
#define MSCCLPPCHECKGOTO(call, res, label) \
|
||||
do { \
|
||||
res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#define MSCCLPPWAIT(call, cond, abortFlagPtr) \
|
||||
do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
if (mscclppDebugNoWarn == 0) \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return mscclppInternalError; \
|
||||
} \
|
||||
if (tmpAbortFlag) \
|
||||
NEQCHECK(*tmpAbortFlag, 0); \
|
||||
#define MSCCLPPWAIT(call, cond, abortFlagPtr) \
|
||||
do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return mscclppInternalError; \
|
||||
} \
|
||||
if (tmpAbortFlag) NEQCHECK(*tmpAbortFlag, 0); \
|
||||
} while (!(cond));
|
||||
|
||||
#define MSCCLPPWAITGOTO(call, cond, abortFlagPtr, res, label) \
|
||||
do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
if (mscclppDebugNoWarn == 0) \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
if (tmpAbortFlag) \
|
||||
NEQCHECKGOTO(*tmpAbortFlag, 0, res, label); \
|
||||
#define MSCCLPPWAITGOTO(call, cond, abortFlagPtr, res, label) \
|
||||
do { \
|
||||
volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
|
||||
res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
goto label; \
|
||||
} \
|
||||
if (tmpAbortFlag) NEQCHECKGOTO(*tmpAbortFlag, 0, res, label); \
|
||||
} while (!(cond));
|
||||
|
||||
#define MSCCLPPCHECKTHREAD(a, args) \
|
||||
do { \
|
||||
if (((args)->ret = (a)) != mscclppSuccess && (args)->ret != mscclppInProgress) { \
|
||||
INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, (args)->ret); \
|
||||
return args; \
|
||||
} \
|
||||
#define MSCCLPPCHECKTHREAD(a, args) \
|
||||
do { \
|
||||
if (((args)->ret = (a)) != mscclppSuccess && (args)->ret != mscclppInProgress) { \
|
||||
INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, (args)->ret); \
|
||||
return args; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CUDACHECKTHREAD(a) \
|
||||
do { \
|
||||
if ((a) != cudaSuccess) { \
|
||||
INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, args->ret); \
|
||||
args->ret = mscclppUnhandledCudaError; \
|
||||
return args; \
|
||||
} \
|
||||
#define CUDACHECKTHREAD(a) \
|
||||
do { \
|
||||
if ((a) != cudaSuccess) { \
|
||||
INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, args->ret); \
|
||||
args->ret = mscclppUnhandledCudaError; \
|
||||
return args; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -7,37 +7,38 @@
|
||||
#ifndef MSCCLPP_CHECKS_HPP_
|
||||
#define MSCCLPP_CHECKS_HPP_
|
||||
|
||||
#include "debug.h"
|
||||
#include "errors.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#define MSCCLPPTHROW(call) \
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
throw mscclpp::Error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res), \
|
||||
res); \
|
||||
} \
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
#define MSCCLPPTHROW(call) \
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
throw mscclpp::Error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res), \
|
||||
ErrorCode::InvalidUsage); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw mscclpp::CudaError(std::string("Cuda failure '") + cudaGetErrorString(err) + "'", err); \
|
||||
} \
|
||||
#define CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw mscclpp::CudaError(std::string("Cuda failure '") + cudaGetErrorString(err) + "'", err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
const char* errStr; \
|
||||
cuGetErrorString(err, &errStr); \
|
||||
throw mscclpp::CuError(std::string("Cu failure '") + std::string(errStr) + "'", err); \
|
||||
} \
|
||||
#define CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
const char* errStr; \
|
||||
cuGetErrorString(err, &errStr); \
|
||||
throw mscclpp::CuError(std::string("Cu failure '") + std::string(errStr) + "'", err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCLPP_COMM_H_
|
||||
#define MSCCLPP_COMM_H_
|
||||
|
||||
#include "ib.hpp"
|
||||
#include "proxy.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
struct mscclppBufferRegistration
|
||||
{
|
||||
void* data;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct mscclppConn
|
||||
{
|
||||
int connId;
|
||||
mscclppTransport_t transport;
|
||||
int remoteRank;
|
||||
uint64_t buffSize;
|
||||
struct mscclppDevConn* devConn;
|
||||
struct mscclppHostConn* hostConn;
|
||||
|
||||
std::vector<mscclppBufferRegistration> bufferRegistrations;
|
||||
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
|
||||
|
||||
mscclpp::IbCtx* ibCtx;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
std::vector<uint64_t> npkitFreeReqIds;
|
||||
#endif
|
||||
};
|
||||
|
||||
struct mscclppComm
|
||||
{
|
||||
struct mscclppConn conns[MAXCONNECTIONS];
|
||||
struct mscclppDevConn devConns[MAXCONNECTIONS];
|
||||
int nConns;
|
||||
|
||||
void* bootstrap;
|
||||
|
||||
// Magic number for all network communication. Not a security key -- only goal is to detect mismatches.
|
||||
uint64_t magic;
|
||||
|
||||
int rank; // my rank in the communicator
|
||||
int nRanks; // number of GPUs in communicator
|
||||
int cudaDev; // my cuda device index
|
||||
int devNumaNode; // my device's NUMA node
|
||||
|
||||
// Flag to ask MSCCLPP kernels to abort
|
||||
volatile uint32_t* abortFlag;
|
||||
|
||||
std::unique_ptr<mscclpp::IbCtx> ibContext[MSCCLPP_IB_MAX_DEVS];
|
||||
struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM];
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,19 +1,19 @@
|
||||
#ifndef MSCCL_COMMUNICATOR_HPP_
|
||||
#define MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
#include <memory>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class ConnectionBase;
|
||||
|
||||
struct Communicator::Impl
|
||||
{
|
||||
struct Communicator::Impl {
|
||||
std::vector<std::shared_ptr<ConnectionBase>> connections_;
|
||||
std::vector<std::shared_ptr<Setuppable>> toSetup_;
|
||||
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
|
||||
@@ -27,6 +27,6 @@ struct Communicator::Impl
|
||||
IbCtx* getIbContext(Transport ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCL_COMMUNICATOR_HPP_
|
||||
#endif // MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
|
||||
#include <time.h>
|
||||
|
||||
class mscclppConfig
|
||||
{
|
||||
public:
|
||||
class mscclppConfig {
|
||||
public:
|
||||
time_t bootstrapConnectionTimeout = 30;
|
||||
|
||||
static mscclppConfig* getInstance();
|
||||
time_t getBootstrapConnectionTimeoutConfig();
|
||||
void setBootstrapConnectionTimeoutConfig(time_t timeout);
|
||||
|
||||
private:
|
||||
private:
|
||||
mscclppConfig() = default;
|
||||
mscclppConfig(const mscclppConfig&) = delete;
|
||||
mscclppConfig& operator=(const mscclppConfig&) = delete;
|
||||
@@ -20,4 +19,4 @@ private:
|
||||
static mscclppConfig _instance;
|
||||
};
|
||||
|
||||
#endif // end include guard
|
||||
#endif // end include guard
|
||||
|
||||
@@ -2,34 +2,34 @@
|
||||
#define MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
// TODO(saemal): make this configurable
|
||||
#define MSCCLPP_POLLING_WAIT 10000 // in microseconds
|
||||
#define MSCCLPP_POLLING_WAIT 3e7 // in microseconds
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// TODO: Add functionality to these classes for Communicator to do connectionSetup
|
||||
|
||||
class ConnectionBase : public Connection, public Setuppable
|
||||
{
|
||||
class ConnectionBase : public Connection, public Setuppable {
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
|
||||
public:
|
||||
public:
|
||||
ConnectionBase(int remoteRank, int tag);
|
||||
|
||||
int remoteRank() override;
|
||||
int tag() override;
|
||||
};
|
||||
|
||||
class CudaIpcConnection : public ConnectionBase
|
||||
{
|
||||
class CudaIpcConnection : public ConnectionBase {
|
||||
cudaStream_t stream;
|
||||
|
||||
public:
|
||||
public:
|
||||
CudaIpcConnection(int remoteRank, int tag);
|
||||
|
||||
~CudaIpcConnection();
|
||||
@@ -44,14 +44,13 @@ public:
|
||||
void flush() override;
|
||||
};
|
||||
|
||||
class IBConnection : public ConnectionBase
|
||||
{
|
||||
class IBConnection : public ConnectionBase {
|
||||
Transport transport_;
|
||||
Transport remoteTransport_;
|
||||
IbQp* qp;
|
||||
int numSignaledSends;
|
||||
|
||||
public:
|
||||
public:
|
||||
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
|
||||
|
||||
Transport transport() override;
|
||||
@@ -68,6 +67,6 @@ public:
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CONNECTION_HPP_
|
||||
#endif // MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
@@ -7,20 +7,20 @@
|
||||
#ifndef MSCCLPP_DEBUG_H_
|
||||
#define MSCCLPP_DEBUG_H_
|
||||
|
||||
#include "mscclpp.h"
|
||||
#include <chrono>
|
||||
#include <stdio.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <limits.h>
|
||||
#include <pthread.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mscclpp.h"
|
||||
|
||||
// Conform to pthread and NVTX standard
|
||||
#define MSCCLPP_THREAD_NAMELEN 16
|
||||
|
||||
typedef enum
|
||||
{
|
||||
typedef enum {
|
||||
MSCCLPP_LOG_NONE = 0,
|
||||
MSCCLPP_LOG_VERSION = 1,
|
||||
MSCCLPP_LOG_WARN = 2,
|
||||
@@ -28,8 +28,7 @@ typedef enum
|
||||
MSCCLPP_LOG_ABORT = 4,
|
||||
MSCCLPP_LOG_TRACE = 5
|
||||
} mscclppDebugLogLevel;
|
||||
typedef enum
|
||||
{
|
||||
typedef enum {
|
||||
MSCCLPP_INIT = 1,
|
||||
MSCCLPP_COLL = 2,
|
||||
MSCCLPP_P2P = 4,
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#ifndef MSCCLPP_EPOCH_HPP_
|
||||
#define MSCCLPP_EPOCH_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct alignas(16) EpochIds
|
||||
{
|
||||
uint64_t outbound_;
|
||||
uint64_t inboundReplica_;
|
||||
};
|
||||
|
||||
struct DeviceEpoch
|
||||
{
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*expectedInboundEpochId_) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(epochIds_->outbound_) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
EpochIds* epochIds_;
|
||||
uint64_t* expectedInboundEpochId_;
|
||||
};
|
||||
|
||||
class Epoch
|
||||
{
|
||||
std::shared_ptr<Connection> connection_;
|
||||
DeviceEpoch device_;
|
||||
RegisteredMemory localEpochIdsRegMem_;
|
||||
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;
|
||||
|
||||
public:
|
||||
Epoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
Epoch(const Epoch&) = delete;
|
||||
~Epoch();
|
||||
|
||||
void signal();
|
||||
|
||||
DeviceEpoch deviceEpoch()
|
||||
{
|
||||
return device_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
@@ -1,156 +0,0 @@
|
||||
#ifndef MSCCLPP_GDR_H_
|
||||
#define MSCCLPP_GDR_H_
|
||||
|
||||
#include "align.h"
|
||||
#include "alloc.h"
|
||||
#include "checks.h"
|
||||
#include "debug.h"
|
||||
#include "gdrapi.h"
|
||||
|
||||
// These can be used if the GDR library isn't thread safe
|
||||
#include <pthread.h>
|
||||
extern pthread_mutex_t gdrLock;
|
||||
#define GDRLOCK() pthread_mutex_lock(&gdrLock)
|
||||
#define GDRUNLOCK() pthread_mutex_unlock(&gdrLock)
|
||||
#define GDRLOCKCALL(cmd, ret) \
|
||||
do { \
|
||||
GDRLOCK(); \
|
||||
ret = cmd; \
|
||||
GDRUNLOCK(); \
|
||||
} while (false)
|
||||
|
||||
#define GDRCHECK(cmd) \
|
||||
do { \
|
||||
int e; \
|
||||
/* GDRLOCKCALL(cmd, e); */ \
|
||||
e = cmd; \
|
||||
if (e != 0) { \
|
||||
WARN("GDRCOPY failure %d", e); \
|
||||
return mscclppSystemError; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
gdr_t wrap_gdr_open(void);
|
||||
mscclppResult_t wrap_gdr_close(gdr_t g);
|
||||
mscclppResult_t wrap_gdr_pin_buffer(gdr_t g, unsigned long addr, size_t size, uint64_t p2p_token, uint32_t va_space,
|
||||
gdr_mh_t* handle);
|
||||
mscclppResult_t wrap_gdr_unpin_buffer(gdr_t g, gdr_mh_t handle);
|
||||
mscclppResult_t wrap_gdr_get_info(gdr_t g, gdr_mh_t handle, gdr_info_t* info);
|
||||
mscclppResult_t wrap_gdr_map(gdr_t g, gdr_mh_t handle, void** va, size_t size);
|
||||
mscclppResult_t wrap_gdr_unmap(gdr_t g, gdr_mh_t handle, void* va, size_t size);
|
||||
|
||||
// Global GDR driver handle
|
||||
extern gdr_t mscclppGdrCopy;
|
||||
|
||||
typedef struct gdr_mem_desc
|
||||
{
|
||||
void* gdrDevMem;
|
||||
void* gdrMap;
|
||||
size_t gdrOffset;
|
||||
size_t gdrMapSize;
|
||||
gdr_mh_t gdrMh;
|
||||
} gdr_mem_desc_t;
|
||||
|
||||
static gdr_t mscclppGdrInit()
|
||||
{
|
||||
// int libMajor, libMinor, drvMajor, drvMinor;
|
||||
gdr_t handle = wrap_gdr_open();
|
||||
|
||||
// if (handle != NULL) {
|
||||
// mscclppResult_t res;
|
||||
|
||||
// // Query the version of libgdrapi
|
||||
// MSCCLPPCHECKGOTO(wrap_gdr_runtime_get_version(&libMajor, &libMinor), res, error);
|
||||
|
||||
// // Query the version of gdrdrv driver
|
||||
// MSCCLPPCHECKGOTO(wrap_gdr_driver_get_version(handle, &drvMajor, &drvMinor), res, error);
|
||||
|
||||
// // Only support GDRAPI 2.1 and later
|
||||
// if (libMajor < 2 || (libMajor == 2 && libMinor < 1) || drvMajor < 2 || (drvMajor == 2 && drvMinor < 1)) {
|
||||
// goto error;
|
||||
// }
|
||||
// else
|
||||
// INFO(MSCCLPP_INIT, "GDRCOPY enabled library %d.%d driver %d.%d", libMajor, libMinor, drvMajor, drvMinor);
|
||||
// }
|
||||
return handle;
|
||||
// error:
|
||||
// if (handle != NULL) (void) wrap_gdr_close(handle);
|
||||
// return NULL;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
mscclppResult_t mscclppGdrCudaCallocDebug(T** ptr, T** devPtr, size_t nelem, void** gdrDesc, const char* filefunc,
|
||||
int line)
|
||||
{
|
||||
mscclppResult_t result = mscclppSuccess;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
*ptr = nullptr;
|
||||
*devPtr = nullptr;
|
||||
*gdrDesc = nullptr;
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
gdr_info_t info;
|
||||
size_t mapSize;
|
||||
gdr_mh_t mh;
|
||||
char* devMem;
|
||||
void* gdrMap;
|
||||
ssize_t off;
|
||||
gdr_mem_desc_t* md;
|
||||
uint64_t alignedAddr;
|
||||
size_t align;
|
||||
|
||||
mapSize = sizeof(T) * nelem;
|
||||
|
||||
// GDRCOPY Pinned buffer has to be a minimum of a GPU_PAGE_SIZE
|
||||
ALIGN_SIZE(mapSize, GPU_PAGE_SIZE);
|
||||
// GDRCOPY Pinned buffer has to be GPU_PAGE_SIZE aligned too
|
||||
MSCCLPPCHECKGOTO(mscclppCudaCalloc(&devMem, mapSize + GPU_PAGE_SIZE - 1), result, finish);
|
||||
alignedAddr = (((uint64_t)devMem) + GPU_PAGE_OFFSET) & GPU_PAGE_MASK;
|
||||
align = alignedAddr - (uint64_t)devMem;
|
||||
MSCCLPPCHECKGOTO(wrap_gdr_pin_buffer(mscclppGdrCopy, alignedAddr, mapSize, 0, 0, &mh), result, finish);
|
||||
|
||||
MSCCLPPCHECKGOTO(wrap_gdr_map(mscclppGdrCopy, mh, &gdrMap, mapSize), result, finish);
|
||||
|
||||
MSCCLPPCHECKGOTO(wrap_gdr_get_info(mscclppGdrCopy, mh, &info), result, finish);
|
||||
|
||||
// Will offset ever be non zero ?
|
||||
off = info.va - alignedAddr;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&md, 1), result, finish);
|
||||
md->gdrDevMem = devMem;
|
||||
md->gdrMap = gdrMap;
|
||||
md->gdrMapSize = mapSize;
|
||||
md->gdrOffset = off + align;
|
||||
md->gdrMh = mh;
|
||||
*gdrDesc = md;
|
||||
|
||||
*ptr = (T*)((char*)gdrMap + off);
|
||||
if (devPtr)
|
||||
*devPtr = (T*)(devMem + off + align);
|
||||
|
||||
TRACE(mscclpp_INIT, "GDRCOPY : allocated devMem %p gdrMap %p offset %lx mh %lx mapSize %zi at %p", md->gdrDevMem,
|
||||
md->gdrMap, md->gdrOffset, md->gdrMh.h, md->gdrMapSize, *ptr);
|
||||
|
||||
return mscclppSuccess;
|
||||
|
||||
finish:
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
if (*ptr == nullptr)
|
||||
WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T));
|
||||
INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
|
||||
return result;
|
||||
}
|
||||
#define mscclppGdrCudaCalloc(...) mscclppGdrCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
static mscclppResult_t mscclppGdrCudaFree(void* gdrDesc)
|
||||
{
|
||||
gdr_mem_desc_t* md = (gdr_mem_desc_t*)gdrDesc;
|
||||
MSCCLPPCHECK(wrap_gdr_unmap(mscclppGdrCopy, md->gdrMh, md->gdrMap, md->gdrMapSize));
|
||||
MSCCLPPCHECK(wrap_gdr_unpin_buffer(mscclppGdrCopy, md->gdrMh));
|
||||
CUDACHECK(cudaFree(md->gdrDevMem));
|
||||
free(md);
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -12,22 +12,20 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct IbMrInfo
|
||||
{
|
||||
struct IbMrInfo {
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
class IbMr
|
||||
{
|
||||
public:
|
||||
class IbMr {
|
||||
public:
|
||||
~IbMr();
|
||||
|
||||
IbMrInfo getInfo() const;
|
||||
const void* getBuff() const;
|
||||
uint32_t getLkey() const;
|
||||
|
||||
private:
|
||||
private:
|
||||
IbMr(void* pd, void* buff, std::size_t size);
|
||||
|
||||
void* mr;
|
||||
@@ -38,8 +36,7 @@ private:
|
||||
};
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct IbQpInfo
|
||||
{
|
||||
struct IbQpInfo {
|
||||
uint16_t lid;
|
||||
uint8_t port;
|
||||
uint8_t linkLayer;
|
||||
@@ -50,9 +47,8 @@ struct IbQpInfo
|
||||
bool is_grh;
|
||||
};
|
||||
|
||||
class IbQp
|
||||
{
|
||||
public:
|
||||
class IbQp {
|
||||
public:
|
||||
~IbQp();
|
||||
|
||||
void rtr(const IbQpInfo& info);
|
||||
@@ -68,7 +64,7 @@ public:
|
||||
IbQpInfo& getInfo();
|
||||
const void* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
private:
|
||||
IbQp(void* ctx, void* pd, int port);
|
||||
|
||||
IbQpInfo info;
|
||||
@@ -83,9 +79,8 @@ private:
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
class IbCtx
|
||||
{
|
||||
public:
|
||||
class IbCtx {
|
||||
public:
|
||||
IbCtx(const std::string& devName);
|
||||
~IbCtx();
|
||||
|
||||
@@ -94,7 +89,7 @@ public:
|
||||
|
||||
const std::string& getDevName() const;
|
||||
|
||||
private:
|
||||
private:
|
||||
bool isPortUsable(int port) const;
|
||||
int getAnyActivePort() const;
|
||||
|
||||
@@ -105,6 +100,6 @@ private:
|
||||
std::list<std::unique_ptr<IbMr>> mrs;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
|
||||
|
||||
#include <mscclppfifo.h>
|
||||
|
||||
#include <vector>
|
||||
// #includa <cuda_runtime.h>
|
||||
|
||||
@@ -19,8 +20,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct alignas(16) mscclppDevConnSignalEpochId
|
||||
{
|
||||
struct alignas(16) mscclppDevConnSignalEpochId {
|
||||
// every signal(), increaments this and either:
|
||||
// 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy
|
||||
// 2) gpu thread directly writes it to remoteSignalEpochId->device
|
||||
@@ -93,39 +93,30 @@ using mscclppBufferHandle_t = uint32_t;
|
||||
* The two endpoint can concurrently use the same connection provided they are writing (puts) on different
|
||||
* indices in the registered buffer.
|
||||
**************************************************************************************************************/
|
||||
struct mscclppDevConn
|
||||
{
|
||||
struct mscclppDevConn {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
__forceinline__ __device__ void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) {
|
||||
fifo.push(mscclppData, dstDataOffset, srcDataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dataOffset, uint64_t dataSize)
|
||||
{
|
||||
put(dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
__forceinline__ __device__ void put(uint64_t dataOffset, uint64_t dataSize) { put(dataOffset, dataOffset, dataSize); }
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
__forceinline__ __device__ void signal() {
|
||||
epochIncrement();
|
||||
fifo.push(mscclppFlag, 0, 0, 1);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) {
|
||||
epochIncrement();
|
||||
fifo.push(mscclppData | mscclppFlag, dstDataOffset, srcDataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize) {
|
||||
putWithSignal(dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
uint64_t dataSize) {
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize);
|
||||
while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
@@ -133,13 +124,11 @@ struct mscclppDevConn
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize)
|
||||
{
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize) {
|
||||
putWithSignalAndFlush(dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
__forceinline__ __device__ void flush() {
|
||||
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
@@ -150,26 +139,23 @@ struct mscclppDevConn
|
||||
|
||||
// Version that uses the SM directly to do the copy, instead of using the proxy thread like the functions above.
|
||||
__forceinline__ __device__ void putDirect(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize,
|
||||
uint32_t threadId, uint32_t numThreads)
|
||||
{
|
||||
uint32_t threadId, uint32_t numThreads) {
|
||||
uint64_t* src = (uint64_t*)((char*)localBuff + srcDataOffset);
|
||||
uint64_t* dst = (uint64_t*)((char*)remoteBuff + dstDataOffset);
|
||||
// assume the memory is aligned to 8 bytes
|
||||
size_t nElem =
|
||||
dataSize % sizeof(uint64_t) ? (dataSize + sizeof(uint64_t)) / sizeof(uint64_t) : dataSize / sizeof(uint64_t);
|
||||
dataSize % sizeof(uint64_t) ? (dataSize + sizeof(uint64_t)) / sizeof(uint64_t) : dataSize / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putDirect(uint64_t dataOffset, uint64_t dataSize, uint32_t threadId,
|
||||
uint32_t numThreads)
|
||||
{
|
||||
uint32_t numThreads) {
|
||||
putDirect(dataOffset, dataOffset, dataSize, threadId, numThreads);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signalDirect()
|
||||
{
|
||||
__forceinline__ __device__ void signalDirect() {
|
||||
// This fence ensures that the writes from a preceding putDirect() are visible on the peer GPU before the
|
||||
// incremented epoch id is visible.
|
||||
__threadfence_system();
|
||||
@@ -177,26 +163,21 @@ struct mscclppDevConn
|
||||
*(volatile uint64_t*)&(remoteSignalEpochId->device) = localSignalEpochId->device;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void waitDirect()
|
||||
{
|
||||
__forceinline__ __device__ void waitDirect() {
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->device) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
}
|
||||
__forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(localSignalEpochId->device) += 1; }
|
||||
|
||||
#endif // __CUDACC__
|
||||
#endif // __CUDACC__
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
@@ -223,8 +204,7 @@ struct mscclppDevConn
|
||||
};
|
||||
|
||||
// Host interface for mscclppDevCon functionality
|
||||
struct mscclppHostConn
|
||||
{
|
||||
struct mscclppHostConn {
|
||||
virtual ~mscclppHostConn() = default;
|
||||
virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0;
|
||||
virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
@@ -239,25 +219,21 @@ typedef struct mscclppDevConn mscclppDevConn_t;
|
||||
typedef struct mscclppHostConn mscclppHostConn_t;
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
typedef struct
|
||||
{
|
||||
typedef struct {
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
} mscclppUniqueId;
|
||||
|
||||
struct mscclppRegisteredMemoryP2P
|
||||
{
|
||||
struct mscclppRegisteredMemoryP2P {
|
||||
void* remoteBuff;
|
||||
const void* IbMr;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemory
|
||||
{
|
||||
struct mscclppRegisteredMemory {
|
||||
std::vector<mscclppRegisteredMemoryP2P> p2p;
|
||||
};
|
||||
|
||||
/* Error type */
|
||||
typedef enum
|
||||
{
|
||||
typedef enum {
|
||||
mscclppSuccess = 0,
|
||||
mscclppUnhandledCudaError = 1,
|
||||
mscclppSystemError = 2,
|
||||
@@ -279,10 +255,9 @@ typedef enum
|
||||
mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* uniqueId);
|
||||
|
||||
/* Transport Types */
|
||||
typedef enum
|
||||
{
|
||||
typedef enum {
|
||||
mscclppTransportP2P = 0,
|
||||
mscclppTransportSHM = 1, // TODO(chhwang): not implemented yet
|
||||
mscclppTransportSHM = 1, // TODO(chhwang): not implemented yet
|
||||
mscclppTransportIB = 2,
|
||||
} mscclppTransport_t;
|
||||
|
||||
@@ -520,7 +495,7 @@ mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegister
|
||||
size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
} // end extern "C"
|
||||
#endif
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
#endif // MSCCLPP_H_
|
||||
|
||||
@@ -7,12 +7,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef enum : uint64_t
|
||||
{
|
||||
mscclppData = 0x1,
|
||||
mscclppFlag = 0x2,
|
||||
mscclppSync = 0x4
|
||||
} mscclppTriggerType_t;
|
||||
typedef enum : uint64_t { mscclppData = 0x1, mscclppFlag = 0x2, mscclppSync = 0x4 } mscclppTriggerType_t;
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
@@ -23,17 +18,16 @@ typedef enum : uint64_t
|
||||
// the summation of number of bits must be 128 or less
|
||||
union alignas(16) mscclppTrigger {
|
||||
uint64_t value[2];
|
||||
struct
|
||||
{
|
||||
struct {
|
||||
// first 64 bits: value[0]
|
||||
uint64_t dataSize : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcDataOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// second 64 bits: value[1]
|
||||
uint64_t dstDataOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_CONNID - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_CONNID - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
};
|
||||
|
||||
@@ -52,13 +46,11 @@ typedef mscclppTrigger* mscclppTrigger_t;
|
||||
* Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct mscclppConcurrentFifo
|
||||
{
|
||||
struct mscclppConcurrentFifo {
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ uint64_t push(uint64_t type, uint64_t dstDataOffset, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
uint64_t dataSize) {
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1);
|
||||
while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail))
|
||||
;
|
||||
@@ -71,16 +63,16 @@ struct mscclppConcurrentFifo
|
||||
return curFifoHead;
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
mscclppTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
|
||||
#endif // __CUDACC__
|
||||
mscclppTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
|
||||
uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
|
||||
// occasionally to device
|
||||
uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device
|
||||
int connId;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
} // end extern "C"
|
||||
#endif
|
||||
|
||||
#endif // MSCCLPPFIFO_H_
|
||||
#endif // MSCCLPPFIFO_H_
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
#ifndef MSCCLPP_PROXY_H_
|
||||
#define MSCCLPP_PROXY_H_
|
||||
|
||||
#include "comm.h"
|
||||
#include "mscclpp.h"
|
||||
#include <atomic>
|
||||
#include <cuda_runtime.h>
|
||||
#include <pthread.h>
|
||||
|
||||
#define MSCCLPP_PROXY_MAX_NUM (MSCCLPP_IB_MAX_DEVS + 1) // One is for a P2P proxy.
|
||||
#include <atomic>
|
||||
|
||||
typedef enum
|
||||
{
|
||||
#include "comm.h"
|
||||
#include "mscclpp.h"
|
||||
|
||||
#define MSCCLPP_PROXY_MAX_NUM (MSCCLPP_IB_MAX_DEVS + 1) // One is for a P2P proxy.
|
||||
|
||||
typedef enum {
|
||||
MSCCLPP_PROXY_RUN_STATE_IDLE = 0,
|
||||
MSCCLPP_PROXY_RUN_STATE_RUNNING,
|
||||
MSCCLPP_PROXY_RUN_STATE_EXITING,
|
||||
} mscclppProxyRunState_t;
|
||||
|
||||
struct mscclppProxyFifo
|
||||
{
|
||||
struct mscclppProxyFifo {
|
||||
mscclppResult_t create();
|
||||
mscclppResult_t destroy();
|
||||
mscclppResult_t poll(mscclppTrigger* trigger);
|
||||
@@ -52,15 +52,14 @@ struct mscclppProxyFifo
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
struct mscclppProxyState
|
||||
{
|
||||
struct mscclppProxyState {
|
||||
mscclppTransport_t transportType;
|
||||
pthread_t thread;
|
||||
mscclppProxyRunState_t run;
|
||||
|
||||
int numaNodeToBind;
|
||||
mscclpp::IbCtx* ibContext; // For IB connection only
|
||||
cudaStream_t p2pStream; // for P2P DMA engine only
|
||||
mscclpp::IbCtx* ibContext; // For IB connection only
|
||||
cudaStream_t p2pStream; // for P2P DMA engine only
|
||||
|
||||
struct mscclppProxyFifo fifo;
|
||||
};
|
||||
|
||||
@@ -1,37 +1,35 @@
|
||||
#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
#define MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "errors.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct TransportInfo
|
||||
{
|
||||
struct TransportInfo {
|
||||
Transport transport;
|
||||
|
||||
// TODO: rewrite this using std::variant or something
|
||||
bool ibLocal;
|
||||
union {
|
||||
struct
|
||||
{
|
||||
struct {
|
||||
cudaIpcMemHandle_t cudaIpcBaseHandle;
|
||||
size_t cudaIpcOffsetFromBase;
|
||||
};
|
||||
struct
|
||||
{
|
||||
struct {
|
||||
const IbMr* ibMr;
|
||||
IbMrInfo ibMrInfo;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct RegisteredMemory::Impl
|
||||
{
|
||||
struct RegisteredMemory::Impl {
|
||||
void* data;
|
||||
size_t size;
|
||||
int rank;
|
||||
@@ -42,17 +40,16 @@ struct RegisteredMemory::Impl
|
||||
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
|
||||
Impl(const std::vector<char>& data);
|
||||
|
||||
TransportInfo& getTransportInfo(Transport transport)
|
||||
{
|
||||
TransportInfo& getTransportInfo(Transport transport) {
|
||||
for (auto& entry : transportInfos) {
|
||||
if (entry.transport == transport) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
throw Error("Transport data not found", mscclppInternalError);
|
||||
throw Error("Transport data not found", ErrorCode::InternalError);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
#ifndef MSCCLPP_REGISTERED_PTR_HPP_
|
||||
#define MSCCLPP_REGISTERED_PTR_HPP_
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
template <typename T> class RegisteredPtr
|
||||
{
|
||||
RegisteredMemory memory;
|
||||
size_t offset;
|
||||
|
||||
public:
|
||||
RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset)
|
||||
{
|
||||
}
|
||||
RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0)
|
||||
{
|
||||
}
|
||||
~RegisteredPtr()
|
||||
{
|
||||
}
|
||||
|
||||
RegisteredMemory memory()
|
||||
{
|
||||
return memory;
|
||||
}
|
||||
|
||||
T* data()
|
||||
{
|
||||
return reinterpret_cast<T*>(memory.data());
|
||||
}
|
||||
|
||||
size_t size()
|
||||
{
|
||||
return memory.size() / sizeof(T);
|
||||
}
|
||||
|
||||
size_t offset()
|
||||
{
|
||||
return offset;
|
||||
}
|
||||
|
||||
RegisteredPtr<T> operator+(size_t offset)
|
||||
{
|
||||
return RegisteredPtr<T>(memory, this->offset + offset);
|
||||
}
|
||||
|
||||
// TODO: all other relevant overloads
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_REGISTERED_PTR_HPP_
|
||||
@@ -7,7 +7,6 @@
|
||||
#ifndef MSCCLPP_SOCKET_H_
|
||||
#define MSCCLPP_SOCKET_H_
|
||||
|
||||
#include "mscclpp.h"
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
@@ -16,9 +15,11 @@
|
||||
#include <stddef.h>
|
||||
#include <sys/socket.h>
|
||||
|
||||
#include "mscclpp.h"
|
||||
|
||||
#define MAX_IFS 16
|
||||
#define MAX_IF_NAME_SIZE 16
|
||||
#define SLEEP_INT 1000 // connection retry sleep interval in usec
|
||||
#define SLEEP_INT 1000 // connection retry sleep interval in usec
|
||||
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
|
||||
#define MSCCLPP_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL
|
||||
|
||||
@@ -29,8 +30,7 @@ union mscclppSocketAddress {
|
||||
struct sockaddr_in6 sin6;
|
||||
};
|
||||
|
||||
enum mscclppSocketState
|
||||
{
|
||||
enum mscclppSocketState {
|
||||
mscclppSocketStateNone = 0,
|
||||
mscclppSocketStateInitialized = 1,
|
||||
mscclppSocketStateAccepting = 2,
|
||||
@@ -44,8 +44,7 @@ enum mscclppSocketState
|
||||
mscclppSocketStateNum = 10
|
||||
};
|
||||
|
||||
enum mscclppSocketType
|
||||
{
|
||||
enum mscclppSocketType {
|
||||
mscclppSocketTypeUnknown = 0,
|
||||
mscclppSocketTypeBootstrap = 1,
|
||||
mscclppSocketTypeProxy = 2,
|
||||
@@ -53,8 +52,7 @@ enum mscclppSocketType
|
||||
mscclppSocketTypeNetIb = 4
|
||||
};
|
||||
|
||||
struct mscclppSocket
|
||||
{
|
||||
struct mscclppSocket {
|
||||
int fd;
|
||||
int acceptFd;
|
||||
int connectRetries;
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#ifndef MSCCLPP_UTILS_H_
|
||||
#define MSCCLPP_UTILS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "mscclpp.h"
|
||||
#include <chrono>
|
||||
#include <stdint.h>
|
||||
|
||||
// int mscclppCudaCompCap();
|
||||
|
||||
@@ -27,8 +29,7 @@ uint64_t getHostHash();
|
||||
uint64_t getPidHash();
|
||||
mscclppResult_t getRandomData(void* buffer, size_t bytes);
|
||||
|
||||
struct netIf
|
||||
{
|
||||
struct netIf {
|
||||
char prefix[64];
|
||||
int port;
|
||||
};
|
||||
@@ -36,11 +37,9 @@ struct netIf
|
||||
int parseStringList(const char* string, struct netIf* ifList, int maxList);
|
||||
bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact);
|
||||
|
||||
static long log2i(long n)
|
||||
{
|
||||
static long log2i(long n) {
|
||||
long l = 0;
|
||||
while (n >>= 1)
|
||||
l++;
|
||||
while (n >>= 1) l++;
|
||||
return l;
|
||||
}
|
||||
|
||||
@@ -50,16 +49,13 @@ int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end);
|
||||
|
||||
/* get any bytes of random data from /dev/urandom, return 0 if it succeeds; else
|
||||
* return -1 */
|
||||
inline mscclppResult_t getRandomData(void* buffer, size_t bytes)
|
||||
{
|
||||
inline mscclppResult_t getRandomData(void* buffer, size_t bytes) {
|
||||
mscclppResult_t ret = mscclppSuccess;
|
||||
if (bytes > 0) {
|
||||
const size_t one = 1UL;
|
||||
FILE* fp = fopen("/dev/urandom", "r");
|
||||
if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one)
|
||||
ret = mscclppSystemError;
|
||||
if (fp)
|
||||
fclose(fp);
|
||||
if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) ret = mscclppSystemError;
|
||||
if (fp) fclose(fp);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -1,54 +1,40 @@
|
||||
#ifndef MSCCLPP_UTILS_HPP_
|
||||
#define MSCCLPP_UTILS_HPP_
|
||||
|
||||
#include <chrono>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct Timer
|
||||
{
|
||||
struct Timer {
|
||||
std::chrono::steady_clock::time_point start;
|
||||
|
||||
Timer()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
Timer() { start = std::chrono::steady_clock::now(); }
|
||||
|
||||
int64_t elapsed()
|
||||
{
|
||||
int64_t elapsed() {
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
void reset() { start = std::chrono::steady_clock::now(); }
|
||||
|
||||
void print(const char* name)
|
||||
{
|
||||
void print(const char* name) {
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
printf("%s: %ld us\n", name, elapsed);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScopedTimer
|
||||
{
|
||||
struct ScopedTimer {
|
||||
Timer timer;
|
||||
const char* name;
|
||||
|
||||
ScopedTimer(const char* name) : name(name)
|
||||
{
|
||||
}
|
||||
ScopedTimer(const char* name) : name(name) {}
|
||||
|
||||
~ScopedTimer()
|
||||
{
|
||||
timer.print(name);
|
||||
}
|
||||
~ScopedTimer() { timer.print(name); }
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_UTILS_HPP_
|
||||
#endif // MSCCLPP_UTILS_HPP_
|
||||
|
||||
920
src/init.cc
920
src/init.cc
@@ -1,920 +0,0 @@
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "bootstrap.h"
|
||||
#include "checks.h"
|
||||
#include "config.h"
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
#include "gdr.h"
|
||||
#endif
|
||||
#include "infiniband/verbs.h"
|
||||
#include "mscclpp.h"
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
static uint64_t hashUniqueId(mscclppUniqueId const& id)
|
||||
{
|
||||
char const* bytes = (char const*)&id;
|
||||
uint64_t h = 0xdeadbeef;
|
||||
for (int i = 0; i < (int)sizeof(mscclppUniqueId); i++) {
|
||||
h ^= h >> 32;
|
||||
h *= 0x8db3db47fa2994ad;
|
||||
h += bytes[i];
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER;
|
||||
static bool initialized = false;
|
||||
// static size_t maxLocalSizeBytes = 0;
|
||||
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
|
||||
gdr_t mscclppGdrCopy = NULL;
|
||||
|
||||
mscclppResult_t initGdrCopy()
|
||||
{
|
||||
if (mscclppGdrCopy == NULL) {
|
||||
mscclppGdrCopy = mscclppGdrInit();
|
||||
if (mscclppGdrCopy == NULL) {
|
||||
WARN("GDR init failed");
|
||||
return mscclppSystemError;
|
||||
}
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static mscclppResult_t mscclppInit()
|
||||
{
|
||||
if (__atomic_load_n(&initialized, __ATOMIC_ACQUIRE))
|
||||
return mscclppSuccess;
|
||||
pthread_mutex_lock(&initLock);
|
||||
if (!initialized) {
|
||||
// Always initialize bootstrap network
|
||||
MSCCLPPCHECK(bootstrapNetInit());
|
||||
|
||||
__atomic_store_n(&initialized, true, __ATOMIC_RELEASE);
|
||||
}
|
||||
pthread_mutex_unlock(&initLock);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static std::string mscclppShmFileName(mscclppComm_t comm, int rank)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "mscclpp." << std::hex << comm->magic << "." << rank;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out)
|
||||
{
|
||||
MSCCLPPCHECK(mscclppInit());
|
||||
// mscclppCHECK(PtrCheck(out, "GetUniqueId", "out"));
|
||||
mscclppResult_t res = bootstrapGetUniqueId((struct mscclppBootstrapHandle*)out);
|
||||
TRACE_CALL("mscclppGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out));
|
||||
return res;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size)
|
||||
{
|
||||
MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank)
|
||||
{
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
MSCCLPPCHECK(initGdrCopy());
|
||||
#endif
|
||||
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
mscclppComm_t _comm = NULL;
|
||||
// uint64_t hash = getHostHash();
|
||||
// uint64_t *hashes;
|
||||
// std::map<uint64_t, int> hashToNode;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail);
|
||||
_comm->rank = rank;
|
||||
_comm->nRanks = nranks;
|
||||
_comm->devNumaNode = -1;
|
||||
// We assume that the user has set the device to the intended one already
|
||||
CUDACHECK(cudaGetDevice(&_comm->cudaDev));
|
||||
|
||||
MSCCLPPCHECK(bootstrapNetInit(ipPortPair));
|
||||
mscclppBootstrapHandle handle;
|
||||
MSCCLPPCHECK(bootstrapGetUniqueId(&handle, rank == 0, ipPortPair));
|
||||
_comm->magic = handle.magic;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t**)&_comm->abortFlag, 1), res, fail);
|
||||
MSCCLPPCHECK(bootstrapInit(&handle, _comm));
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
// Init NPKit
|
||||
MSCCLPPCHECK(NpKit::Init(_comm->rank));
|
||||
#endif
|
||||
|
||||
*comm = _comm;
|
||||
return res;
|
||||
fail:
|
||||
if (_comm) {
|
||||
if (_comm->abortFlag)
|
||||
mscclppCudaHostFree((void*)_comm->abortFlag);
|
||||
free(_comm);
|
||||
}
|
||||
if (comm)
|
||||
*comm = NULL;
|
||||
return res;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank)
|
||||
{
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
MSCCLPPCHECK(initGdrCopy());
|
||||
#endif
|
||||
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
mscclppComm_t _comm = NULL;
|
||||
mscclppBootstrapHandle* handle = (mscclppBootstrapHandle*)&id;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail);
|
||||
_comm->rank = rank;
|
||||
_comm->nRanks = nranks;
|
||||
// We assume that the user has set the device to the intended one already
|
||||
CUDACHECK(cudaGetDevice(&_comm->cudaDev));
|
||||
|
||||
MSCCLPPCHECK(bootstrapNetInit());
|
||||
_comm->magic = handle->magic;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t**)&_comm->abortFlag, 1), res, fail);
|
||||
MSCCLPPCHECK(bootstrapInit(handle, _comm));
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
// Init NPKit
|
||||
MSCCLPPCHECK(NpKit::Init(_comm->rank));
|
||||
#endif
|
||||
|
||||
*comm = _comm;
|
||||
return res;
|
||||
fail:
|
||||
if (_comm) {
|
||||
if (_comm->abortFlag)
|
||||
mscclppCudaHostFree((void*)_comm->abortFlag);
|
||||
free(_comm);
|
||||
}
|
||||
if (comm)
|
||||
*comm = NULL;
|
||||
return res;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
|
||||
{
|
||||
#if defined(ENABLE_NPKIT)
|
||||
const char* npkitDumpDir = nullptr;
|
||||
#endif
|
||||
|
||||
if (comm == NULL)
|
||||
return mscclppSuccess;
|
||||
|
||||
for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) {
|
||||
struct mscclppProxyState* proxyState = comm->proxyState[i];
|
||||
if (proxyState) {
|
||||
MSCCLPPCHECK(proxyState->fifo.destroy());
|
||||
if (proxyState->p2pStream)
|
||||
CUDACHECK(cudaStreamDestroy(proxyState->p2pStream));
|
||||
free(proxyState);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i]) {
|
||||
comm->ibContext[i].reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < comm->nConns; i++) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
if (conn) {
|
||||
MSCCLPPCHECK(mscclppCudaFree(conn->devConn->localSignalEpochId));
|
||||
MSCCLPPCHECK(mscclppCudaFree(conn->devConn->waitEpochId));
|
||||
if (conn->hostConn)
|
||||
delete conn->hostConn;
|
||||
}
|
||||
}
|
||||
|
||||
if (comm->bootstrap)
|
||||
MSCCLPPCHECK(bootstrapClose(comm->bootstrap));
|
||||
|
||||
mscclppCudaHostFree((void*)comm->abortFlag);
|
||||
free(comm);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
// Dump NPKit events and shutdown
|
||||
npkitDumpDir = getenv("NPKIT_DUMP_DIR");
|
||||
if (npkitDumpDir == nullptr) {
|
||||
WARN("NPKIT_DUMP_DIR is empty");
|
||||
} else {
|
||||
MSCCLPPCHECK(NpKit::Dump(npkitDumpDir));
|
||||
}
|
||||
MSCCLPPCHECK(NpKit::Shutdown());
|
||||
#endif
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API const char* mscclppGetErrorString(mscclppResult_t code)
|
||||
{
|
||||
switch (code) {
|
||||
case mscclppSuccess:
|
||||
return "no error";
|
||||
case mscclppUnhandledCudaError:
|
||||
return "unhandled cuda error";
|
||||
case mscclppSystemError:
|
||||
return "unhandled system error";
|
||||
case mscclppInternalError:
|
||||
return "internal error";
|
||||
case mscclppInvalidArgument:
|
||||
return "invalid argument";
|
||||
case mscclppInvalidUsage:
|
||||
return "invalid usage";
|
||||
case mscclppRemoteError:
|
||||
return "remote process exited or there was a network error";
|
||||
case mscclppInProgress:
|
||||
return "MSCCL++ operation in progress";
|
||||
default:
|
||||
return "unknown result code";
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag,
|
||||
mscclppDevConn_t** devConn)
|
||||
{
|
||||
for (int i = 0; i < comm->nConns; i++) {
|
||||
if (comm->devConns[i].remoteRank == remoteRank && comm->devConns[i].tag == tag) {
|
||||
*devConn = &comm->devConns[i];
|
||||
return mscclppSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns)
|
||||
{
|
||||
*nConns = comm->nConns;
|
||||
*devConns = comm->devConns;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
|
||||
static void npkitInitReqIds(struct mscclppComm* comm)
|
||||
{
|
||||
for (int i = 0; i < comm->nConns; i++) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
conn->npkitUsedReqIds.resize(0);
|
||||
conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS);
|
||||
for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) {
|
||||
conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size)
|
||||
{
|
||||
uint64_t reqId = 0;
|
||||
if (conn->npkitFreeReqIds.size() == 0) {
|
||||
reqId = conn->npkitUsedReqIds.size();
|
||||
} else {
|
||||
reqId = conn->npkitFreeReqIds.back();
|
||||
conn->npkitFreeReqIds.pop_back();
|
||||
}
|
||||
conn->npkitUsedReqIds.push_back(reqId);
|
||||
NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
|
||||
}
|
||||
|
||||
static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type)
|
||||
{
|
||||
while (conn->npkitUsedReqIds.size()) {
|
||||
uint64_t reqId = conn->npkitUsedReqIds.back();
|
||||
NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId);
|
||||
conn->npkitFreeReqIds.push_back(reqId);
|
||||
conn->npkitUsedReqIds.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define npkitInitReqIds(comm)
|
||||
|
||||
#define npkitCollectEntryEvent(conn, type, size)
|
||||
|
||||
#define npkitCollectExitEvents(conn, type)
|
||||
|
||||
#endif
|
||||
|
||||
struct mscclppHostP2PConn : mscclppHostConn
|
||||
{
|
||||
mscclppHostP2PConn(mscclppConn* _conn, cudaStream_t _stream) : conn(_conn), p2pStream(_stream)
|
||||
{
|
||||
}
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset);
|
||||
void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset);
|
||||
CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream));
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal()
|
||||
{
|
||||
CUDACHECKNORET(cudaMemcpyAsync(&conn->devConn->remoteSignalEpochId->proxy,
|
||||
&(conn->devConn->localSignalEpochId->device), sizeof(uint64_t),
|
||||
cudaMemcpyDeviceToDevice, p2pStream));
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait()
|
||||
{
|
||||
}
|
||||
void flush()
|
||||
{
|
||||
CUDACHECKNORET(cudaStreamSynchronize(p2pStream));
|
||||
npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
cudaStream_t p2pStream;
|
||||
};
|
||||
|
||||
struct mscclppHostIBConn : mscclppHostConn
|
||||
{
|
||||
mscclppHostIBConn(mscclppConn* conn) : conn(conn)
|
||||
{
|
||||
this->ibQp = NULL;
|
||||
}
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal()
|
||||
{
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait()
|
||||
{
|
||||
}
|
||||
void flush()
|
||||
{
|
||||
bool isWaiting = true;
|
||||
while (isWaiting) {
|
||||
int wcNum = this->ibQp->pollCq();
|
||||
if (wcNum < 0) {
|
||||
WARN("pollCq failed: errno %d", errno);
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i);
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
continue;
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
isWaiting = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
mscclpp::IbQp* ibQp;
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
std::vector<mscclpp::IbMrInfo> remoteIbMrInfos;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag,
|
||||
mscclppTransport_t transportType, const char* ibDev)
|
||||
{
|
||||
// save this processes numa binding and set it to the one closest to the device
|
||||
// so that all the allocation are close to the device
|
||||
if (comm->devNumaNode == -1) {
|
||||
// in case this is our first time
|
||||
MSCCLPPCHECK(getDeviceNumaNode(comm->cudaDev, &comm->devNumaNode));
|
||||
INFO(MSCCLPP_INIT, "NUMA node of device %d is set to %d", comm->cudaDev, comm->devNumaNode);
|
||||
}
|
||||
// save numa node bitmask to change it back to user's numa node
|
||||
mscclppNumaState curProcessState;
|
||||
MSCCLPPCHECK(getNumaState(&curProcessState));
|
||||
// change to device's numa node so that the following allocation are close to the device
|
||||
MSCCLPPCHECK(numaBind(comm->devNumaNode));
|
||||
|
||||
if (comm->nConns == MAXCONNECTIONS) {
|
||||
WARN("Too many connections made");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
int connId = comm->nConns;
|
||||
struct mscclppConn* conn = &comm->conns[connId];
|
||||
conn->connId = connId;
|
||||
conn->transport = transportType;
|
||||
conn->buffSize = 0;
|
||||
|
||||
conn->ibCtx = NULL;
|
||||
int ibDevIdx = -1;
|
||||
if (transportType == mscclppTransportIB) {
|
||||
// Check if an IB context exists
|
||||
int firstNullIdx = -1;
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i] == NULL) {
|
||||
if (firstNullIdx == -1) {
|
||||
firstNullIdx = i;
|
||||
}
|
||||
} else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
ibDevIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If not, create a new one
|
||||
if (ibDevIdx == -1) {
|
||||
// Create a new context.
|
||||
ibDevIdx = firstNullIdx;
|
||||
comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev)));
|
||||
}
|
||||
// Set the ib context for this conn
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx].get();
|
||||
|
||||
} else if (transportType == mscclppTransportP2P) {
|
||||
// do the rest of the initialization later
|
||||
} else if (transportType == mscclppTransportSHM) {
|
||||
WARN("Shared memory interconnection is not implemented yet!");
|
||||
return mscclppInternalError;
|
||||
} else {
|
||||
WARN("Unexpected connection type!");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
|
||||
// Find/create a proxy state for the given connection
|
||||
struct mscclppProxyState* proxyState = NULL;
|
||||
// First see if there is a matching context
|
||||
// If not, find the first empty proxy
|
||||
int firstEmptyProxyIndex = -1;
|
||||
for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) {
|
||||
struct mscclppProxyState* curProxy = comm->proxyState[i];
|
||||
if (curProxy && (curProxy->transportType == transportType)) {
|
||||
if ((transportType == mscclppTransportIB && curProxy->ibContext == conn->ibCtx) ||
|
||||
(transportType == mscclppTransportP2P)) {
|
||||
proxyState = curProxy;
|
||||
break; // we found the matching context
|
||||
}
|
||||
}
|
||||
if (curProxy == NULL && firstEmptyProxyIndex == -1) {
|
||||
firstEmptyProxyIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (proxyState == NULL && firstEmptyProxyIndex == -1) {
|
||||
WARN("Too many proxies have been allocated!");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
|
||||
// If we couldn't find a matching context, create one
|
||||
if (proxyState == NULL) {
|
||||
MSCCLPPCHECK(mscclppCalloc(&proxyState, 1));
|
||||
MSCCLPPCHECK(proxyState->fifo.create());
|
||||
|
||||
if (transportType == mscclppTransportIB) {
|
||||
proxyState->ibContext = conn->ibCtx;
|
||||
proxyState->p2pStream = NULL;
|
||||
} else if (transportType == mscclppTransportP2P) {
|
||||
proxyState->ibContext = NULL;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&proxyState->p2pStream, cudaStreamNonBlocking));
|
||||
}
|
||||
proxyState->numaNodeToBind = comm->devNumaNode;
|
||||
|
||||
// INFO(MSCCLPP_INIT, "NUMA node for device %d is %d", cudaDev, *numaNode);
|
||||
proxyState->transportType = transportType;
|
||||
comm->proxyState[firstEmptyProxyIndex] = proxyState;
|
||||
}
|
||||
if (proxyState == NULL) {
|
||||
// Cannot reach
|
||||
WARN("Proxy allocation failed!");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
if (transportType == mscclppTransportIB) {
|
||||
conn->hostConn = new mscclppHostIBConn(conn);
|
||||
} else if (transportType == mscclppTransportP2P) {
|
||||
conn->hostConn = new mscclppHostP2PConn(conn, proxyState->p2pStream);
|
||||
}
|
||||
|
||||
struct mscclppDevConn* devConn = &comm->devConns[connId];
|
||||
|
||||
conn->devConn = devConn;
|
||||
conn->devConn->localBuff = nullptr;
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->localSignalEpochId, 1));
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1));
|
||||
conn->devConn->remoteRank = remoteRank;
|
||||
conn->devConn->tag = tag;
|
||||
conn->devConn->fifo.connId = connId;
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev;
|
||||
#else
|
||||
conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifo;
|
||||
#endif
|
||||
conn->devConn->fifo.triggerFifoHead = proxyState->fifo.fifoHead;
|
||||
conn->devConn->fifo.triggerFifoTail = proxyState->fifo.fifoTailDev;
|
||||
|
||||
comm->nConns++;
|
||||
|
||||
// change the numa binding back to user's
|
||||
MSCCLPPCHECK(setNumaState(curProcessState));
|
||||
|
||||
mscclppBufferHandle_t signalHandle = -1;
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId,
|
||||
sizeof(mscclppDevConnSignalEpochId), &signalHandle));
|
||||
if (signalHandle != 0) {
|
||||
WARN("signal handle should be 0");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff,
|
||||
uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev)
|
||||
{
|
||||
int connId = comm->nConns;
|
||||
MSCCLPPCHECK(mscclppConnectWithoutBuffer(comm, remoteRank, tag, transportType, ibDev));
|
||||
struct mscclppConn* conn = &comm->conns[connId];
|
||||
|
||||
conn->buffSize = buffSize;
|
||||
conn->devConn->localBuff = localBuff;
|
||||
|
||||
mscclppBufferHandle_t localBuffHandle = -1;
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, localBuff, buffSize, &localBuffHandle));
|
||||
if (localBuffHandle != 1) {
|
||||
WARN("data buffer handle should be 1");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff,
|
||||
uint64_t buffSize, mscclppBufferHandle_t* handle)
|
||||
{
|
||||
if (connIdx >= comm->nConns) {
|
||||
WARN("connIdx out of range");
|
||||
return mscclppInvalidArgument;
|
||||
}
|
||||
mscclppConn& conn = comm->conns[connIdx];
|
||||
*handle = conn.bufferRegistrations.size();
|
||||
conn.bufferRegistrations.emplace_back();
|
||||
conn.bufferRegistrations.back().data = localBuff;
|
||||
conn.bufferRegistrations.back().size = buffSize;
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct mscclppBufferRegistrationInfo
|
||||
{
|
||||
cudaIpcMemHandle_t cudaHandle;
|
||||
mscclpp::IbMrInfo ibMrInfo;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct connInfo
|
||||
{
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
|
||||
|
||||
struct header
|
||||
{
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
int numBufferInfos;
|
||||
};
|
||||
|
||||
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag)
|
||||
{
|
||||
header h;
|
||||
h.infoQp = infoQp;
|
||||
h.numBufferInfos = bufferInfos.size();
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(),
|
||||
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag)
|
||||
{
|
||||
header h;
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
infoQp = h.infoQp;
|
||||
bufferInfos.resize(h.numBufferInfos);
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(),
|
||||
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*input*/)
|
||||
{
|
||||
if (conn == NULL) {
|
||||
WARN("connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto& bufReg : conn->bufferRegistrations) {
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data));
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*output*/)
|
||||
{
|
||||
if (connInfo == NULL || conn == NULL) {
|
||||
WARN("ipcHandles or connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
if (connInfo->bufferInfos.size() < 1) {
|
||||
WARN("at least 1 buffer info expected");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
|
||||
// Open all remote registered buffers
|
||||
for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) {
|
||||
mscclppBufferRegistration newBufReg;
|
||||
CUDACHECK(
|
||||
cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
newBufReg.size = connInfo->bufferInfos[i].size;
|
||||
conn->remoteBufferRegistrations.push_back(newBufReg);
|
||||
}
|
||||
|
||||
if (conn->remoteBufferRegistrations[0].size != sizeof(mscclppDevConnSignalEpochId)) {
|
||||
WARN("buffer registration zero size doesn't match sizeof(mscclppDevConnSignalEpochId)");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data;
|
||||
|
||||
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote
|
||||
// buffer to the first remote data buffer
|
||||
if (conn->remoteBufferRegistrations.size() > 1) {
|
||||
conn->devConn->remoteBuff = conn->remoteBufferRegistrations[1].data;
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/)
|
||||
{
|
||||
if (connInfo == NULL || conn == NULL) {
|
||||
WARN("connInfo or connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppDevConn* devConn = conn->devConn;
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
devConn->remoteBuff = NULL;
|
||||
devConn->remoteSignalEpochId = NULL;
|
||||
|
||||
mscclpp::IbCtx* ibCtx = conn->ibCtx;
|
||||
if (hostConn->ibQp == NULL) {
|
||||
hostConn->ibQp = ibCtx->createQp();
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto& bufReg : conn->bufferRegistrations) {
|
||||
hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId)));
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo();
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
|
||||
connInfo->infoQp = hostConn->ibQp->getInfo();
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*output*/)
|
||||
{
|
||||
if (connInfo == NULL || conn == NULL) {
|
||||
WARN("ipcHandles or connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
hostConn->ibQp->rtr(connInfo->infoQp);
|
||||
hostConn->ibQp->rts();
|
||||
|
||||
// No remote pointers to set with IB, so we just set the Mrs
|
||||
|
||||
// Push the Mrs for all the remote registered buffers
|
||||
for (size_t i = 1; i < connInfo->bufferInfos.size(); i++) {
|
||||
hostConn->remoteIbMrInfos.push_back(connInfo->bufferInfos[i].ibMrInfo);
|
||||
|
||||
mscclppBufferRegistration newBufReg;
|
||||
newBufReg.data = nullptr;
|
||||
newBufReg.size = connInfo->bufferInfos[i].size;
|
||||
conn->remoteBufferRegistrations.push_back(newBufReg);
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
{
|
||||
// Send info to peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
|
||||
struct connInfo cInfo;
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
MSCCLPPCHECK(mscclppP2pConnectionSetupStart(&cInfo, conn));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn));
|
||||
}
|
||||
// TODO: from saemal: do we possibly deadlock if there are too many outstanding sends?
|
||||
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo,
|
||||
// sizeof(cInfo)));
|
||||
MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
|
||||
}
|
||||
|
||||
// Recv info from peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct connInfo cInfo;
|
||||
MSCCLPPCHECK(cInfo.recvOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
MSCCLPPCHECK(mscclppP2pConnectionSetupEnd(&cInfo, conn));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
MSCCLPPCHECK(mscclppIbConnectionSetupEnd(&cInfo, conn));
|
||||
}
|
||||
}
|
||||
|
||||
// a barrier to ensure setup on all gpus are done and we can return to the user
|
||||
MSCCLPPCHECK(mscclppBootstrapBarrier(comm));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct bufferInfo
|
||||
{
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
mscclpp::IbMrInfo infoBuffMr;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
|
||||
mscclppRegisteredMemory* regMem)
|
||||
{
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct bufferInfo bInfo;
|
||||
const mscclpp::IbMr* ibBuffMr;
|
||||
|
||||
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
ibBuffMr = conn->ibCtx->registerMr(local_memory, size);
|
||||
bInfo.infoBuffMr = ibBuffMr->getInfo();
|
||||
ibMrs.emplace_back(ibBuffMr);
|
||||
}
|
||||
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));
|
||||
}
|
||||
|
||||
// Recv info from peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct bufferInfo bInfo;
|
||||
|
||||
mscclppRegisteredMemoryP2P p2p;
|
||||
p2p.IbMr = NULL;
|
||||
p2p.remoteBuff = NULL;
|
||||
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));
|
||||
|
||||
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void**)&p2p.remoteBuff, bInfo.handleBuff, cudaIpcMemLazyEnablePeerAccess));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
p2p.IbMr = ibMrs[i];
|
||||
}
|
||||
regMem->p2p.push_back(p2p);
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem,
|
||||
void* srcBuff, size_t size, uint32_t srcOffset,
|
||||
uint32_t dstOffset, int64_t stream)
|
||||
{
|
||||
int ret = 0;
|
||||
// TODO: transport should be an argument too so user can decide which transport to use
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
void* dstBuff = regMem->p2p[i].remoteBuff;
|
||||
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
|
||||
} else {
|
||||
WARN("mscclppRegisteredBufferWrite not implemented for IB");
|
||||
return mscclppInternalError;
|
||||
// TODO: fix the following (Olli: probably by including the relevant ibBuffMr in the mscclppRegisteredMemory)
|
||||
// struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
// hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrRemoteInfo, (uint32_t)size,
|
||||
// /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false);
|
||||
// if ((ret = hostConn->ibQp->postSend()) != 0) {
|
||||
// // Return value is errno.
|
||||
// WARN("data postSend failed: errno %d", ret);
|
||||
// }
|
||||
// // ??
|
||||
// // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
|
||||
// // trigger.fields.connId);
|
||||
}
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
// TODO: destroy registered buffer
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
|
||||
{
|
||||
npkitInitReqIds(comm);
|
||||
MSCCLPPCHECK(mscclppProxyCreate(comm));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm)
|
||||
{
|
||||
int* tmp = new int[comm->nRanks];
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
delete[] tmp;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppProxyStop(mscclppComm_t comm)
|
||||
{
|
||||
// a barrier to make sure all ranks are done with their work before stopping the proxy
|
||||
MSCCLPPCHECK(mscclppBootstrapBarrier(comm));
|
||||
|
||||
MSCCLPPCHECK(mscclppProxyDestroy(comm));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank)
|
||||
{
|
||||
if (comm == NULL || rank == NULL) {
|
||||
WARN("comm or rank cannot be null");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
*rank = comm->rank;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
|
||||
{
|
||||
if (comm == NULL || size == NULL) {
|
||||
WARN("comm or size cannot be null");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
*size = comm->nRanks;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API void mscclppDefaultLogHandler(const char* msg)
|
||||
{
|
||||
mscclppDebugDefaultLogHandler(msg);
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler)
|
||||
{
|
||||
return mscclppDebugSetLogHandler(handler);
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout)
|
||||
{
|
||||
mscclppConfig* config = mscclppConfig::getInstance();
|
||||
config->setBootstrapConnectionTimeoutConfig(timeout);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include "npkit.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "npkit/npkit.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
uint64_t NpKit::rank_ = 0;
|
||||
|
||||
@@ -16,8 +18,7 @@ NpKitEventCollectContext* NpKit::cpu_collect_contexts_ = nullptr;
|
||||
uint64_t NpKit::cpu_base_system_timestamp_ = 0;
|
||||
uint64_t NpKit::cpu_base_steady_timestamp_ = 0;
|
||||
|
||||
mscclppResult_t NpKit::Init(int rank)
|
||||
{
|
||||
mscclppResult_t NpKit::Init(int rank) {
|
||||
uint64_t i = 0;
|
||||
NpKitEventCollectContext ctx;
|
||||
ctx.event_buffer_head = 0;
|
||||
@@ -47,8 +48,7 @@ mscclppResult_t NpKit::Init(int rank)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t NpKit::Dump(const std::string& dump_dir)
|
||||
{
|
||||
mscclppResult_t NpKit::Dump(const std::string& dump_dir) {
|
||||
uint64_t i = 0;
|
||||
std::string dump_file_path;
|
||||
|
||||
@@ -113,8 +113,7 @@ mscclppResult_t NpKit::Dump(const std::string& dump_dir)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t NpKit::Shutdown()
|
||||
{
|
||||
mscclppResult_t NpKit::Shutdown() {
|
||||
uint64_t i = 0;
|
||||
|
||||
// Free CPU event data structures
|
||||
@@ -134,13 +133,9 @@ mscclppResult_t NpKit::Shutdown()
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts()
|
||||
{
|
||||
return gpu_collect_contexts_;
|
||||
}
|
||||
NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts() { return gpu_collect_contexts_; }
|
||||
|
||||
void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id)
|
||||
{
|
||||
void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id) {
|
||||
uint64_t event_buffer_head = cpu_collect_contexts_[channel_id].event_buffer_head;
|
||||
if (event_buffer_head < kMaxNumCpuEventsPerBuffer) {
|
||||
NpKitEvent& event = cpu_collect_contexts_[channel_id].event_buffer[event_buffer_head];
|
||||
@@ -152,8 +147,7 @@ void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t NpKit::GetCpuTimestamp()
|
||||
{
|
||||
uint64_t NpKit::GetCpuTimestamp() {
|
||||
uint64_t cpu_curr_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
return cpu_base_steady_timestamp_ + (cpu_curr_steady_timestamp_ - cpu_base_steady_timestamp_);
|
||||
}
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "npkit/npkit_event.h"
|
||||
#include "npkit/npkit_struct.h"
|
||||
#include "mscclpp.h"
|
||||
#include "npkit_event.h"
|
||||
#include "npkit_struct.h"
|
||||
|
||||
class NpKit
|
||||
{
|
||||
public:
|
||||
class NpKit {
|
||||
public:
|
||||
static const uint64_t kNumGpuEventBuffers = 512;
|
||||
|
||||
static const uint64_t kNumCpuEventBuffers = 32;
|
||||
@@ -21,9 +21,9 @@ public:
|
||||
|
||||
static NpKitEventCollectContext* GetGpuEventCollectContexts();
|
||||
|
||||
#ifdef __CUDACC__
|
||||
static inline __device__ void CollectGpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp,
|
||||
NpKitEventCollectContext* ctx)
|
||||
{
|
||||
NpKitEventCollectContext* ctx) {
|
||||
uint64_t event_buffer_head = ctx->event_buffer_head;
|
||||
if (event_buffer_head < kMaxNumGpuEventsPerBuffer) {
|
||||
NpKitEvent& event = ctx->event_buffer[event_buffer_head];
|
||||
@@ -34,12 +34,13 @@ public:
|
||||
ctx->event_buffer_head++;
|
||||
}
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id);
|
||||
|
||||
static uint64_t GetCpuTimestamp();
|
||||
|
||||
private:
|
||||
private:
|
||||
// 64K * 512 * 16B = 512MB per GPU
|
||||
static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16;
|
||||
|
||||
@@ -7,8 +7,7 @@
|
||||
|
||||
union NpKitEvent {
|
||||
uint64_t bits[2];
|
||||
struct
|
||||
{
|
||||
struct {
|
||||
uint64_t type : 8;
|
||||
uint64_t size : 32;
|
||||
uint64_t rsvd : 24;
|
||||
@@ -16,8 +15,7 @@ union NpKitEvent {
|
||||
} fields;
|
||||
};
|
||||
|
||||
struct NpKitEventCollectContext
|
||||
{
|
||||
struct NpKitEventCollectContext {
|
||||
NpKitEvent* event_buffer;
|
||||
uint64_t event_buffer_head;
|
||||
};
|
||||
263
src/proxy.cc
263
src/proxy.cc
@@ -1,214 +1,101 @@
|
||||
#include "alloc.h"
|
||||
#include "checks.h"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.hpp"
|
||||
#include "socket.h"
|
||||
|
||||
#include <emmintrin.h>
|
||||
#include <map>
|
||||
#include <sys/syscall.h>
|
||||
#include <atomic>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <thread>
|
||||
|
||||
#define MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD 100
|
||||
#include "api.h"
|
||||
#include "utils.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
#define PROXYCUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
WARN("CUDA error from proxy: %s", cudaGetErrorString(err)); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (false)
|
||||
namespace mscclpp {
|
||||
|
||||
#define PROXYMSCCLPPCHECK(call) \
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
/* Print the back trace*/ \
|
||||
if (mscclppDebugNoWarn == 0) \
|
||||
INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0);
|
||||
const int ProxyStopCheckPeriod = 1000;
|
||||
|
||||
struct proxyArgs
|
||||
{
|
||||
struct mscclppComm* comm;
|
||||
struct mscclppProxyState* proxyState;
|
||||
const int ProxyFlushPeriod = 4;
|
||||
|
||||
struct Proxy::Impl {
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
HostProxyFifo fifo;
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit)
|
||||
: handler(handler), threadInit(threadInit), running(false) {}
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppProxyFifo::create()
|
||||
{
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1));
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
MSCCLPPCHECK(
|
||||
mscclppGdrCudaCalloc(&this->triggerFifo, &this->triggerFifoDev, MSCCLPP_PROXY_FIFO_SIZE, &this->triggerFifoDesc));
|
||||
MSCCLPPCHECK(mscclppGdrCudaCalloc(&this->fifoTailDevHostPtr, &this->fifoTailDev, 1, &this->fifoTailDesc));
|
||||
#else
|
||||
MSCCLPPCHECK(mscclppCudaHostCalloc(&this->triggerFifo, MSCCLPP_PROXY_FIFO_SIZE));
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoTailDev, 1));
|
||||
#endif
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&this->stream, cudaStreamNonBlocking));
|
||||
this->fifoTailHost = 0;
|
||||
return mscclppSuccess;
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit) {
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit);
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyFifo::destroy()
|
||||
{
|
||||
MSCCLPPCHECK(mscclppCudaFree(this->fifoHead));
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
MSCCLPPCHECK(mscclppGdrCudaFree(this->triggerFifoDesc));
|
||||
MSCCLPPCHECK(mscclppGdrCudaFree(this->fifoTailDesc));
|
||||
#else
|
||||
MSCCLPPCHECK(mscclppCudaHostFree(this->triggerFifo));
|
||||
MSCCLPPCHECK(mscclppCudaFree(this->fifoTailDev));
|
||||
#endif
|
||||
CUDACHECK(cudaStreamDestroy(this->stream));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) {}
|
||||
|
||||
// return true if the trigger is valid
|
||||
mscclppResult_t mscclppProxyFifo::poll(mscclppTrigger* trigger)
|
||||
{
|
||||
__m128i xmm0 = _mm_load_si128((__m128i*)&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
_mm_store_si128((__m128i*)trigger, xmm0);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyFifo::pop()
|
||||
{
|
||||
*(volatile uint64_t*)(&this->triggerFifo[this->fifoTailHost % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
(this->fifoTailHost)++;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyFifo::flushTail(bool sync)
|
||||
{
|
||||
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
*(volatile uint64_t*)(this->fifoTailDevHostPtr) = this->fifoTailHost;
|
||||
#else
|
||||
CUDACHECK(
|
||||
cudaMemcpyAsync(this->fifoTailDev, &(this->fifoTailHost), sizeof(uint64_t), cudaMemcpyHostToDevice, this->stream));
|
||||
if (sync) {
|
||||
CUDACHECK(cudaStreamSynchronize(this->stream));
|
||||
}
|
||||
#endif
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
static void processTrigger(const mscclppTrigger trigger, mscclppConn* conn)
|
||||
{
|
||||
// Iterate over what send is needed
|
||||
if (trigger.fields.type & mscclppData) {
|
||||
conn->hostConn->put(trigger.fields.dstDataOffset, trigger.fields.srcDataOffset, trigger.fields.dataSize);
|
||||
}
|
||||
|
||||
if (trigger.fields.type & mscclppFlag) {
|
||||
conn->hostConn->signal();
|
||||
}
|
||||
|
||||
// Wait for completion
|
||||
if (trigger.fields.type & mscclppSync) {
|
||||
conn->hostConn->flush();
|
||||
MSCCLPP_API_CPP Proxy::~Proxy() {
|
||||
if (pimpl) {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
void* mscclppProxyService(void* _args)
|
||||
{
|
||||
struct proxyArgs* args = (struct proxyArgs*)_args;
|
||||
struct mscclppComm* comm = args->comm;
|
||||
struct mscclppProxyState* proxyState = args->proxyState;
|
||||
free(_args); // allocated in mscclppProxyCreate
|
||||
MSCCLPP_API_CPP void Proxy::start() {
|
||||
pimpl->running = true;
|
||||
pimpl->service = std::thread([this] {
|
||||
pimpl->threadInit();
|
||||
|
||||
// from this point on, proxy thread will stay close to the device
|
||||
PROXYCUDACHECK(cudaSetDevice(comm->cudaDev));
|
||||
PROXYMSCCLPPCHECK(numaBind(comm->devNumaNode));
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
HostProxyFifo& fifo = this->pimpl->fifo;
|
||||
std::atomic_bool& running = this->pimpl->running;
|
||||
ProxyTrigger trigger;
|
||||
|
||||
struct mscclppProxyFifo* fifo = &proxyState->fifo;
|
||||
volatile mscclppProxyRunState_t* run = &proxyState->run;
|
||||
mscclppTrigger trigger;
|
||||
int runCnt = ProxyStopCheckPeriod;
|
||||
uint64_t flushCnt = 0;
|
||||
for (;;) {
|
||||
if (runCnt-- == 0) {
|
||||
runCnt = ProxyStopCheckPeriod;
|
||||
if (!running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Poll to see if we are ready to send anything
|
||||
fifo.poll(&trigger);
|
||||
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
}
|
||||
|
||||
int runCnt = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD;
|
||||
uint64_t flushCnt = 0;
|
||||
for (;;) {
|
||||
if (runCnt-- == 0) {
|
||||
runCnt = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD;
|
||||
if (*run != MSCCLPP_PROXY_RUN_STATE_RUNNING) {
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
fifo.pop();
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
|
||||
fifo.flushTail();
|
||||
}
|
||||
|
||||
if (result == ProxyHandlerResult::Stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Poll to see if we are ready to send anything
|
||||
PROXYMSCCLPPCHECK(fifo->poll(&trigger));
|
||||
if (trigger.value[0] == 0) {
|
||||
continue; // there is one in progreess
|
||||
}
|
||||
|
||||
mscclppConn* conn = &comm->conns[trigger.fields.connId];
|
||||
processTrigger(trigger, conn);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
PROXYMSCCLPPCHECK(fifo->pop());
|
||||
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
if (((++flushCnt % MSCCLPP_PROXY_FIFO_FLUSH_COUNTER) == 0) || (trigger.fields.type & mscclppSync)) {
|
||||
PROXYMSCCLPPCHECK(fifo->flushTail());
|
||||
}
|
||||
}
|
||||
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
PROXYMSCCLPPCHECK(fifo->flushTail(/*sync=*/true));
|
||||
bool isP2pProxy = (proxyState->ibContext == nullptr);
|
||||
if (isP2pProxy) {
|
||||
cudaStream_t p2pStream = proxyState->p2pStream;
|
||||
PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
|
||||
}
|
||||
*run = MSCCLPP_PROXY_RUN_STATE_IDLE;
|
||||
return NULL;
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
fifo.flushTail(/*sync=*/true);
|
||||
// TODO: do these need to run?
|
||||
// bool isP2pProxy = (proxyState->ibContext == nullptr);
|
||||
// if (isP2pProxy) {
|
||||
// cudaStream_t p2pStream = proxyState->p2pStream;
|
||||
// PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm)
|
||||
{
|
||||
for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) {
|
||||
struct mscclppProxyState* proxyState = comm->proxyState[i];
|
||||
if (proxyState == NULL)
|
||||
break;
|
||||
|
||||
struct proxyArgs* args;
|
||||
MSCCLPPCHECK(mscclppCalloc(&args, 1));
|
||||
args->comm = comm;
|
||||
args->proxyState = proxyState;
|
||||
|
||||
proxyState->run = MSCCLPP_PROXY_RUN_STATE_RUNNING;
|
||||
pthread_create(&proxyState->thread, NULL, mscclppProxyService, args);
|
||||
if (proxyState->transportType == mscclppTransportP2P) {
|
||||
mscclppSetThreadName(proxyState->thread, "MSCCLPP Service P2P - %02d", comm->cudaDev);
|
||||
} else if (proxyState->transportType == mscclppTransportIB) {
|
||||
mscclppSetThreadName(proxyState->thread, "MSCCLPP Service IB - %02d", i);
|
||||
}
|
||||
MSCCLPP_API_CPP void Proxy::stop() {
|
||||
pimpl->running = false;
|
||||
if (pimpl->service.joinable()) {
|
||||
pimpl->service.join();
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm)
|
||||
{
|
||||
for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) {
|
||||
struct mscclppProxyState* proxyState = comm->proxyState[i];
|
||||
if (proxyState == NULL)
|
||||
break;
|
||||
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { return pimpl->fifo; }
|
||||
|
||||
volatile int* run = (volatile int*)&proxyState->run;
|
||||
if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) {
|
||||
continue;
|
||||
}
|
||||
*run = MSCCLPP_PROXY_RUN_STATE_EXITING;
|
||||
while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) {
|
||||
usleep(1000);
|
||||
}
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
} // namespace mscclpp
|
||||
|
||||
112
src/proxy_cpp.cc
112
src/proxy_cpp.cc
@@ -1,112 +0,0 @@
|
||||
#include "api.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "utils.h"
|
||||
#include "utils.hpp"
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
const int ProxyStopCheckPeriod = 1000;
|
||||
|
||||
const int ProxyFlushPeriod = 4;
|
||||
|
||||
struct Proxy::Impl
|
||||
{
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
HostProxyFifo fifo;
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit)
|
||||
: handler(handler), threadInit(threadInit), running(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit)
|
||||
{
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {})
|
||||
{
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::~Proxy()
|
||||
{
|
||||
if (pimpl) {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::start()
|
||||
{
|
||||
pimpl->running = true;
|
||||
pimpl->service = std::thread([this] {
|
||||
pimpl->threadInit();
|
||||
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
HostProxyFifo& fifo = this->pimpl->fifo;
|
||||
std::atomic_bool& running = this->pimpl->running;
|
||||
ProxyTrigger trigger;
|
||||
|
||||
int runCnt = ProxyStopCheckPeriod;
|
||||
uint64_t flushCnt = 0;
|
||||
for (;;) {
|
||||
if (runCnt-- == 0) {
|
||||
runCnt = ProxyStopCheckPeriod;
|
||||
if (!running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Poll to see if we are ready to send anything
|
||||
fifo.poll(&trigger);
|
||||
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
}
|
||||
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
fifo.pop();
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
|
||||
fifo.flushTail();
|
||||
}
|
||||
|
||||
if (result == ProxyHandlerResult::Stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
fifo.flushTail(/*sync=*/true);
|
||||
// TODO: do these need to run?
|
||||
// bool isP2pProxy = (proxyState->ibContext == nullptr);
|
||||
// if (isP2pProxy) {
|
||||
// cudaStream_t p2pStream = proxyState->p2pStream;
|
||||
// PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::stop()
|
||||
{
|
||||
pimpl->running = false;
|
||||
if (pimpl->service.joinable()) {
|
||||
pimpl->service.join();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo()
|
||||
{
|
||||
return pimpl->fifo;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -1,22 +1,24 @@
|
||||
#include "registered_memory.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "utils.h"
|
||||
#include <algorithm>
|
||||
#include <cuda.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
|
||||
: data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports)
|
||||
{
|
||||
: data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) {
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
TransportInfo transportInfo;
|
||||
transportInfo.transport = Transport::CudaIpc;
|
||||
cudaIpcMemHandle_t handle;
|
||||
|
||||
void* baseDataPtr;
|
||||
size_t baseDataSize; // dummy
|
||||
size_t baseDataSize; // dummy
|
||||
CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
|
||||
CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
|
||||
// TODO: bug with offset of base?
|
||||
@@ -35,60 +37,37 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
INFO(MSCCLPP_NET, "IB mr for address %p with size %ld is registered", data, size);
|
||||
};
|
||||
if (transports.has(Transport::IB0))
|
||||
addIb(Transport::IB0);
|
||||
if (transports.has(Transport::IB1))
|
||||
addIb(Transport::IB1);
|
||||
if (transports.has(Transport::IB2))
|
||||
addIb(Transport::IB2);
|
||||
if (transports.has(Transport::IB3))
|
||||
addIb(Transport::IB3);
|
||||
if (transports.has(Transport::IB4))
|
||||
addIb(Transport::IB4);
|
||||
if (transports.has(Transport::IB5))
|
||||
addIb(Transport::IB5);
|
||||
if (transports.has(Transport::IB6))
|
||||
addIb(Transport::IB6);
|
||||
if (transports.has(Transport::IB7))
|
||||
addIb(Transport::IB7);
|
||||
if (transports.has(Transport::IB0)) addIb(Transport::IB0);
|
||||
if (transports.has(Transport::IB1)) addIb(Transport::IB1);
|
||||
if (transports.has(Transport::IB2)) addIb(Transport::IB2);
|
||||
if (transports.has(Transport::IB3)) addIb(Transport::IB3);
|
||||
if (transports.has(Transport::IB4)) addIb(Transport::IB4);
|
||||
if (transports.has(Transport::IB5)) addIb(Transport::IB5);
|
||||
if (transports.has(Transport::IB6)) addIb(Transport::IB6);
|
||||
if (transports.has(Transport::IB7)) addIb(Transport::IB7);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl)
|
||||
{
|
||||
}
|
||||
MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl) {}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
|
||||
|
||||
MSCCLPP_API_CPP void* RegisteredMemory::data()
|
||||
{
|
||||
return pimpl->data;
|
||||
}
|
||||
MSCCLPP_API_CPP void* RegisteredMemory::data() { return pimpl->data; }
|
||||
|
||||
MSCCLPP_API_CPP size_t RegisteredMemory::size()
|
||||
{
|
||||
return pimpl->size;
|
||||
}
|
||||
MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; }
|
||||
|
||||
MSCCLPP_API_CPP int RegisteredMemory::rank()
|
||||
{
|
||||
return pimpl->rank;
|
||||
}
|
||||
MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; }
|
||||
|
||||
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports()
|
||||
{
|
||||
return pimpl->transports;
|
||||
}
|
||||
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; }
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
||||
{
|
||||
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
|
||||
std::vector<char> result;
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
|
||||
if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
|
||||
throw mscclpp::Error("Too many transport info entries", mscclppInternalError);
|
||||
throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError);
|
||||
}
|
||||
int8_t transportCount = pimpl->transportInfos.size();
|
||||
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
|
||||
@@ -102,19 +81,17 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
||||
} else if (AllIBTransports.has(entry.transport)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
||||
} else {
|
||||
throw mscclpp::Error("Unknown transport", mscclppInternalError);
|
||||
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data)
|
||||
{
|
||||
MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data) {
|
||||
return RegisteredMemory(std::make_shared<Impl>(data));
|
||||
}
|
||||
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
{
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
auto it = serialization.begin();
|
||||
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
|
||||
it += sizeof(this->size);
|
||||
@@ -143,12 +120,12 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
it += sizeof(transportInfo.ibMrInfo);
|
||||
transportInfo.ibLocal = false;
|
||||
} else {
|
||||
throw mscclpp::Error("Unknown transport", mscclppInternalError);
|
||||
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
|
||||
}
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (it != serialization.end()) {
|
||||
throw mscclpp::Error("Serialization failed", mscclppInternalError);
|
||||
throw mscclpp::Error("Serialization failed", ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
@@ -163,4 +140,4 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
89
src/utils.cc
89
src/utils.cc
@@ -6,9 +6,10 @@
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <memory>
|
||||
#include <numa.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
// Get current Compute Capability
|
||||
@@ -21,20 +22,17 @@
|
||||
// return ccMajor*10+ccMinor;
|
||||
// }
|
||||
|
||||
mscclppResult_t int64ToBusId(int64_t id, char* busId)
|
||||
{
|
||||
mscclppResult_t int64ToBusId(int64_t id, char* busId) {
|
||||
sprintf(busId, "%04lx:%02lx:%02lx.%01lx", (id) >> 20, (id & 0xff000) >> 12, (id & 0xff0) >> 4, (id & 0xf));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t busIdToInt64(const char* busId, int64_t* id)
|
||||
{
|
||||
char hexStr[17]; // Longest possible int64 hex string + null terminator.
|
||||
mscclppResult_t busIdToInt64(const char* busId, int64_t* id) {
|
||||
char hexStr[17]; // Longest possible int64 hex string + null terminator.
|
||||
int hexOffset = 0;
|
||||
for (int i = 0; hexOffset < sizeof(hexStr) - 1; i++) {
|
||||
char c = busId[i];
|
||||
if (c == '.' || c == ':')
|
||||
continue;
|
||||
if (c == '.' || c == ':') continue;
|
||||
if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f')) {
|
||||
hexStr[hexOffset++] = busId[i];
|
||||
} else
|
||||
@@ -46,8 +44,7 @@ mscclppResult_t busIdToInt64(const char* busId, int64_t* id)
|
||||
}
|
||||
|
||||
// Convert a logical cudaDev index to the NVML device minor number
|
||||
mscclppResult_t getBusId(int cudaDev, std::string* busId)
|
||||
{
|
||||
mscclppResult_t getBusId(int cudaDev, std::string* busId) {
|
||||
// On most systems, the PCI bus ID comes back as in the 0000:00:00.0
|
||||
// format. Still need to allocate proper space in case PCI domain goes
|
||||
// higher.
|
||||
@@ -61,8 +58,7 @@ mscclppResult_t getBusId(int cudaDev, std::string* busId)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode)
|
||||
{
|
||||
mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode) {
|
||||
std::string busId;
|
||||
MSCCLPPCHECK(getBusId(cudaDev, &busId));
|
||||
|
||||
@@ -81,21 +77,18 @@ mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t getHostName(char* hostname, int maxlen, const char delim)
|
||||
{
|
||||
mscclppResult_t getHostName(char* hostname, int maxlen, const char delim) {
|
||||
if (gethostname(hostname, maxlen) != 0) {
|
||||
strncpy(hostname, "unknown", maxlen);
|
||||
return mscclppSystemError;
|
||||
}
|
||||
int i = 0;
|
||||
while ((hostname[i] != delim) && (hostname[i] != '\0') && (i < maxlen - 1))
|
||||
i++;
|
||||
while ((hostname[i] != delim) && (hostname[i] != '\0') && (i < maxlen - 1)) i++;
|
||||
hostname[i] = '\0';
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
uint64_t getHash(const char* string, int n)
|
||||
{
|
||||
uint64_t getHash(const char* string, int n) {
|
||||
// Based on DJB2a, result = result * 33 ^ char
|
||||
uint64_t result = 5381;
|
||||
for (int c = 0; c < n; c++) {
|
||||
@@ -113,8 +106,7 @@ uint64_t getHash(const char* string, int n)
|
||||
* This string can be overridden by using the MSCCLPP_HOSTID env var.
|
||||
*/
|
||||
#define HOSTID_FILE "/proc/sys/kernel/random/boot_id"
|
||||
uint64_t computeHostHash(void)
|
||||
{
|
||||
uint64_t computeHostHash(void) {
|
||||
char hostHash[1024];
|
||||
char* hostId;
|
||||
|
||||
@@ -145,8 +137,7 @@ uint64_t computeHostHash(void)
|
||||
return getHash(hostHash, strlen(hostHash));
|
||||
}
|
||||
|
||||
uint64_t getHostHash(void)
|
||||
{
|
||||
uint64_t getHostHash(void) {
|
||||
thread_local std::unique_ptr<uint64_t> hostHash = std::make_unique<uint64_t>(computeHostHash());
|
||||
return *hostHash;
|
||||
}
|
||||
@@ -157,15 +148,13 @@ uint64_t getHostHash(void)
|
||||
*
|
||||
* $$ $(readlink /proc/self/ns/pid)
|
||||
*/
|
||||
uint64_t getPidHash(void)
|
||||
{
|
||||
uint64_t getPidHash(void) {
|
||||
char pname[1024];
|
||||
// Start off with our pid ($$)
|
||||
sprintf(pname, "%ld", (long)getpid());
|
||||
int plen = strlen(pname);
|
||||
int len = readlink("/proc/self/ns/pid", pname + plen, sizeof(pname) - 1 - plen);
|
||||
if (len < 0)
|
||||
len = 0;
|
||||
if (len < 0) len = 0;
|
||||
|
||||
pname[plen + len] = '\0';
|
||||
TRACE(MSCCLPP_INIT, "unique PID '%s'", pname);
|
||||
@@ -173,10 +162,8 @@ uint64_t getPidHash(void)
|
||||
return getHash(pname, strlen(pname));
|
||||
}
|
||||
|
||||
int parseStringList(const char* string, struct netIf* ifList, int maxList)
|
||||
{
|
||||
if (!string)
|
||||
return 0;
|
||||
int parseStringList(const char* string, struct netIf* ifList, int maxList) {
|
||||
if (!string) return 0;
|
||||
|
||||
const char* ptr = string;
|
||||
|
||||
@@ -192,8 +179,7 @@ int parseStringList(const char* string, struct netIf* ifList, int maxList)
|
||||
ifNum++;
|
||||
ifC = 0;
|
||||
}
|
||||
while (c != ',' && c != '\0')
|
||||
c = *(++ptr);
|
||||
while (c != ',' && c != '\0') c = *(++ptr);
|
||||
} else if (c == ',' || c == '\0') {
|
||||
if (ifC > 0) {
|
||||
ifList[ifNum].prefix[ifC] = '\0';
|
||||
@@ -210,29 +196,22 @@ int parseStringList(const char* string, struct netIf* ifList, int maxList)
|
||||
return ifNum;
|
||||
}
|
||||
|
||||
static bool matchIf(const char* string, const char* ref, bool matchExact)
|
||||
{
|
||||
static bool matchIf(const char* string, const char* ref, bool matchExact) {
|
||||
// Make sure to include '\0' in the exact case
|
||||
int matchLen = matchExact ? strlen(string) + 1 : strlen(ref);
|
||||
return strncmp(string, ref, matchLen) == 0;
|
||||
}
|
||||
|
||||
static bool matchPort(const int port1, const int port2)
|
||||
{
|
||||
if (port1 == -1)
|
||||
return true;
|
||||
if (port2 == -1)
|
||||
return true;
|
||||
if (port1 == port2)
|
||||
return true;
|
||||
static bool matchPort(const int port1, const int port2) {
|
||||
if (port1 == -1) return true;
|
||||
if (port2 == -1) return true;
|
||||
if (port1 == port2) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact)
|
||||
{
|
||||
bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
|
||||
// Make an exception for the case where no user list is defined
|
||||
if (listSize == 0)
|
||||
return true;
|
||||
if (listSize == 0) return true;
|
||||
|
||||
for (int i = 0; i < listSize; i++) {
|
||||
if (matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
|
||||
@@ -242,8 +221,7 @@ bool matchIfList(const char* string, int port, struct netIf* ifList, int listSiz
|
||||
return false;
|
||||
}
|
||||
|
||||
mscclppResult_t numaBind(int node)
|
||||
{
|
||||
mscclppResult_t numaBind(int node) {
|
||||
int totalNumNumaNodes = numa_num_configured_nodes();
|
||||
if (node < 0 || node >= totalNumNumaNodes) {
|
||||
WARN("Invalid NUMA node %d, must be between 0 and %d", node, totalNumNumaNodes);
|
||||
@@ -256,9 +234,7 @@ mscclppResult_t numaBind(int node)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t getNumaState(mscclppNumaState* state)
|
||||
{
|
||||
|
||||
mscclppResult_t getNumaState(mscclppNumaState* state) {
|
||||
mscclppNumaState state_ = numa_get_run_node_mask();
|
||||
if (state_ == NULL) {
|
||||
WARN("Failed to get NUMA node mask of the running process");
|
||||
@@ -268,8 +244,7 @@ mscclppResult_t getNumaState(mscclppNumaState* state)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t setNumaState(mscclppNumaState state)
|
||||
{
|
||||
mscclppResult_t setNumaState(mscclppNumaState state) {
|
||||
if (state == NULL) {
|
||||
WARN("Invalid NUMA state");
|
||||
return mscclppInvalidUsage;
|
||||
@@ -278,12 +253,8 @@ mscclppResult_t setNumaState(mscclppNumaState state)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppTime_t getClock()
|
||||
{
|
||||
return std::chrono::steady_clock::now();
|
||||
}
|
||||
mscclppTime_t getClock() { return std::chrono::steady_clock::now(); }
|
||||
|
||||
int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end)
|
||||
{
|
||||
int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end) {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(end - start).count();
|
||||
}
|
||||
|
||||
20
test/CMakeLists.txt
Normal file
20
test/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
function(add_test_executable name sources)
|
||||
add_executable(${name} ${sources})
|
||||
target_link_libraries(${name} mscclpp CUDA::cudart CUDA::cuda_driver)
|
||||
target_include_directories(${name} PRIVATE ${PROJECT_SOURCE_DIR}/src/include)
|
||||
if(USE_MPI_FOR_TESTS)
|
||||
target_link_libraries(${name} MPI::MPI_CXX)
|
||||
target_compile_definitions(${name} PRIVATE MSCCLPP_USE_MPI_FOR_TESTS)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
add_test_executable(bootstrap_test_cpp bootstrap_test_cpp.cc)
|
||||
add_test_executable(communicator_test_cpp communicator_test_cpp.cu)
|
||||
add_test_executable(allgather_test_cpp allgather_test_cpp.cu)
|
||||
add_test_executable(ib_test ib_test.cc)
|
||||
|
||||
# Unit tests
|
||||
add_executable(unit_tests)
|
||||
target_link_libraries(unit_tests GTest::gtest_main GTest::gmock_main mscclpp)
|
||||
add_subdirectory(unit) # This adds the sources to the mscclpp target
|
||||
gtest_discover_tests(unit_tests DISCOVERY_MODE PRE_TEST)
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include "channel.hpp"
|
||||
#include <mscclpp/channel.hpp>
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include "mpi.h"
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "epoch.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include <mscclpp/epoch.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
@@ -88,12 +88,16 @@ void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr)
|
||||
bool test_device_buffer_write_correctness(int rank, int worldSize, int nRanksPerNode, int dataCount,
|
||||
std::vector<int*>& devicePtr, bool skipLocal = false)
|
||||
{
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++) {
|
||||
std::vector<int> hostBuffer(dataCount, 0);
|
||||
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount * sizeof(int), cudaMemcpyDeviceToHost));
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i / nRanksPerNode == rank / nRanksPerNode && skipLocal) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i * dataCount / worldSize; j < (i + 1) * dataCount / worldSize; j++) {
|
||||
if (hostBuffer[j] != i + n * worldSize) {
|
||||
return false;
|
||||
@@ -104,7 +108,8 @@ bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vec
|
||||
return true;
|
||||
}
|
||||
|
||||
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
void test_write(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize,
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers)
|
||||
@@ -129,7 +134,7 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
|
||||
bool ready = false;
|
||||
int niter = 0;
|
||||
do {
|
||||
ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr);
|
||||
ready = test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr);
|
||||
niter++;
|
||||
if (niter == 10000) {
|
||||
throw std::runtime_error("Polling is stuck.");
|
||||
@@ -139,9 +144,12 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
|
||||
}
|
||||
|
||||
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
|
||||
__global__ void increament_epochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize)
|
||||
{
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize) {
|
||||
@@ -149,7 +157,7 @@ __global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank,
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
|
||||
__global__ void wait_epochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize)
|
||||
{
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize) {
|
||||
@@ -157,14 +165,25 @@ __global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int wo
|
||||
}
|
||||
}
|
||||
|
||||
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers)
|
||||
void test_write_with_device_epochs(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize,
|
||||
mscclpp::Communicator& communicator,
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr,
|
||||
int numBuffers)
|
||||
{
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::DeviceEpoch>> epochs;
|
||||
for (auto entry : connections) {
|
||||
auto& conn = entry.second;
|
||||
epochs.insert({entry.first, std::make_shared<mscclpp::DeviceEpoch>(communicator, conn)});
|
||||
}
|
||||
communicator.setup();
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
|
||||
@@ -173,12 +192,13 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA memory initialization passed" << std::endl;
|
||||
|
||||
mscclpp::DeviceEpoch* deviceEpochs;
|
||||
CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize));
|
||||
mscclpp::DeviceEpoch::DeviceHandle* deviceEpochHandles;
|
||||
CUDATHROW(cudaMalloc(&deviceEpochHandles, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * worldSize));
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch();
|
||||
CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice));
|
||||
mscclpp::DeviceEpoch::DeviceHandle deviceHandle = epochs[i]->deviceHandle();
|
||||
CUDATHROW(cudaMemcpy(&deviceEpochHandles[i], &deviceHandle, sizeof(mscclpp::DeviceEpoch::DeviceHandle),
|
||||
cudaMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
@@ -191,7 +211,7 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
|
||||
increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
increament_epochs<<<1, worldSize>>>(deviceEpochHandles, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
@@ -200,20 +220,78 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
|
||||
}
|
||||
}
|
||||
|
||||
wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
wait_epochs<<<1, worldSize>>>(deviceEpochHandles, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)) {
|
||||
if (!test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr)) {
|
||||
throw std::runtime_error("unexpected result.");
|
||||
}
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---"
|
||||
std::cout << "--- Testing writes with device epochs for " << std::to_string(numBuffers) << " buffers passed ---"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
void test_write_with_host_epochs(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize,
|
||||
mscclpp::Communicator& communicator, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr,
|
||||
int numBuffers)
|
||||
{
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::HostEpoch>> epochs;
|
||||
for (auto entry : connections) {
|
||||
auto& conn = entry.second;
|
||||
if (conn->transport() == mscclpp::Transport::CudaIpc)
|
||||
continue;
|
||||
epochs.insert({entry.first, std::make_shared<mscclpp::HostEpoch>(communicator, conn)});
|
||||
}
|
||||
communicator.setup();
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
|
||||
device_buffer_init(rank, worldSize, dataCount, devicePtr);
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA memory initialization passed" << std::endl;
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Host epochs are created" << std::endl;
|
||||
|
||||
for (int n = 0; n < numBuffers; n++) {
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
epochs[i]->increamentAndSignal();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
epochs[i]->wait();
|
||||
}
|
||||
}
|
||||
|
||||
if (!test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr, true)) {
|
||||
throw std::runtime_error("unexpected result.");
|
||||
}
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing writes with host epochs for " << std::to_string(numBuffers) << " buffers passed ---"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nRanksPerNode)
|
||||
{
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
@@ -227,9 +305,9 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
std::cout << "Communicator initialization passed" << std::endl;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> connections;
|
||||
auto myIbDevice = findIb(rank % nranksPerNode);
|
||||
auto myIbDevice = findIb(rank % nRanksPerNode);
|
||||
|
||||
make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections);
|
||||
make_connections(communicator, rank, worldSize, nRanksPerNode, myIbDevice, connections);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Connection setup passed" << std::endl;
|
||||
|
||||
@@ -251,23 +329,14 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr,
|
||||
numBuffers);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
|
||||
test_write(rank, worldSize, nRanksPerNode, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory,
|
||||
devicePtr, numBuffers);
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs;
|
||||
for (auto entry : connections) {
|
||||
auto& conn = entry.second;
|
||||
epochs.insert({entry.first, std::make_shared<mscclpp::Epoch>(communicator, conn)});
|
||||
}
|
||||
communicator.setup();
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
test_write_with_device_epochs(rank, worldSize, nRanksPerNode, deviceBufferSize, communicator, bootstrap, connections,
|
||||
remoteMemory, localMemory, devicePtr, numBuffers);
|
||||
|
||||
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory,
|
||||
devicePtr, epochs, numBuffers);
|
||||
test_write_with_host_epochs(rank, worldSize, nRanksPerNode, deviceBufferSize, communicator, bootstrap, connections,
|
||||
remoteMemory, localMemory, devicePtr, numBuffers);
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
@@ -287,10 +356,10 @@ int main(int argc, char** argv)
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
|
||||
int shmWorldSize;
|
||||
MPI_Comm_size(shmcomm, &shmWorldSize);
|
||||
int nranksPerNode = shmWorldSize;
|
||||
int nRanksPerNode = shmWorldSize;
|
||||
MPI_Comm_free(&shmcomm);
|
||||
|
||||
test_communicator(rank, worldSize, nranksPerNode);
|
||||
test_communicator(rank, worldSize, nRanksPerNode);
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "checks.h"
|
||||
#include "ib.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <array>
|
||||
#include <string>
|
||||
|
||||
3
test/unit/CMakeLists.txt
Normal file
3
test/unit/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
target_sources(unit_tests PRIVATE
|
||||
core_tests.cc
|
||||
)
|
||||
49
test/unit/core_tests.cc
Normal file
49
test/unit/core_tests.cc
Normal file
@@ -0,0 +1,49 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
class LocalCommunicatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
bootstrap = std::make_shared<mscclpp::Bootstrap>(0, 1);
|
||||
comm = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::Communicator> comm;
|
||||
};
|
||||
|
||||
class MockSetuppable : public mscclpp::Setuppable {
|
||||
public:
|
||||
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
|
||||
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
|
||||
};
|
||||
|
||||
TEST_F(LocalCommunicatorTest, OnSetup) {
|
||||
auto mockSetuppable = std::make_shared<MockSetuppable>();
|
||||
comm->onSetup(mockSetuppable);
|
||||
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
|
||||
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
|
||||
comm->setup();
|
||||
}
|
||||
|
||||
TEST_F(LocalCommunicatorTest, RegisterMemory) {
|
||||
int dummy[42];
|
||||
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
|
||||
EXPECT_EQ(memory.data(), &dummy);
|
||||
EXPECT_EQ(memory.size(), sizeof(dummy));
|
||||
EXPECT_EQ(memory.rank(), 0);
|
||||
EXPECT_EQ(memory.transports(), mscclpp::NoTransports);
|
||||
}
|
||||
|
||||
TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
|
||||
int dummy[42];
|
||||
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
|
||||
comm->sendMemoryOnSetup(memory, 0, 0);
|
||||
auto memoryFuture = comm->recvMemoryOnSetup(0, 0);
|
||||
comm->setup();
|
||||
auto sameMemory = memoryFuture.get();
|
||||
EXPECT_EQ(sameMemory.size(), memory.size());
|
||||
EXPECT_EQ(sameMemory.rank(), memory.rank());
|
||||
EXPECT_EQ(sameMemory.transports(), memory.transports());
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
add_executable(bootstrap_test_cpp bootstrap_test_cpp.cc)
|
||||
target_link_libraries(bootstrap_test_cpp mscclpp MPI::MPI_CXX)
|
||||
|
||||
add_executable(communicator_test_cpp communicator_test_cpp.cu)
|
||||
target_link_libraries(communicator_test_cpp mscclpp MPI::MPI_CXX)
|
||||
|
||||
add_executable(allgather_test_cpp allgather_test_cpp.cu)
|
||||
target_link_libraries(allgather_test_cpp mscclpp MPI::MPI_CXX)
|
||||
|
||||
add_subdirectory(unittests)
|
||||
@@ -1,2 +0,0 @@
|
||||
add_executable(ib_test ib_test.cc)
|
||||
target_link_libraries(ib_test mscclpp MPI::MPI_CXX CUDA::cudart)
|
||||
Reference in New Issue
Block a user