diff --git a/.clang-format b/.clang-format index c3bb0335..bfd118eb 100644 --- a/.clang-format +++ b/.clang-format @@ -1,137 +1,2 @@ ---- -Language: Cpp -# BasedOnStyle: Microsoft -AccessModifierOffset: -2 -AlignAfterOpenBracket: Align -AlignConsecutiveMacros: false -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Right -AlignOperands: true -AlignTrailingComments: true -AllowAllArgumentsOnNextLine: true -AllowAllConstructorInitializersOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine: true -AllowShortBlocksOnASingleLine: Never -AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: None -AllowShortLambdasOnASingleLine: All -AllowShortIfStatementsOnASingleLine: Never -AllowShortLoopsOnASingleLine: false -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: false -AlwaysBreakTemplateDeclarations: MultiLine -BinPackArguments: true -BinPackParameters: true -BraceWrapping: - AfterCaseLabel: false - AfterClass: true - AfterControlStatement: false # true - AfterEnum: true - AfterFunction: true - AfterNamespace: false # true - AfterObjCDeclaration: true - AfterStruct: true - AfterUnion: false - AfterExternBlock: true - BeforeCatch: false # true - BeforeElse: false # true - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Custom -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 120 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: false -ConstructorInitializerIndentWidth: 2 -ContinuationIndentWidth: 2 -Cpp11BracedListStyle: true -DeriveLineEnding: true -DerivePointerAlignment: false -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^"(llvm|llvm-c|clang|clang-c)/' - Priority: 2 - SortPriority: 0 - - Regex: '^(<|"(gtest|gmock|isl|json)/)' - Priority: 3 - SortPriority: 0 - - Regex: '.*' - Priority: 1 - SortPriority: 0 -IncludeIsMainRegex: '(Test)?$' -IncludeIsMainSourceRegex: '' -IndentCaseLabels: false -IndentExternBlock: NoIndent -IndentGotoLabels: true -IndentPPDirectives: None -IndentWidth: 2 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: true -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Auto -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 19 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 1000 -PointerAlignment: Left -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyBlock: false -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInConditionalStatement: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -SpaceBeforeSquareBrackets: false -Standard: Latest -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -TabWidth: 2 -UseCRLF: false -UseTab: Never -... +BasedOnStyle: Google +ColumnLimit: 120 diff --git a/.github/workflows/cpplint.yml b/.github/workflows/cpplint.yml index 0b002f44..c2e6cb43 100644 --- a/.github/workflows/cpplint.yml +++ b/.github/workflows/cpplint.yml @@ -25,4 +25,8 @@ jobs: run: sudo apt-get install -y clang-format-12 - name: Run cpplint - run: make cpplint + run: | + CPPSOURCES=$(find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*" -not -path "./test/*") + PYTHONCPPSOURCES=$(find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)') + clang-format-12 -style=file --verbose --Werror --dry-run ${CPPSOURCES} + clang-format-12 --dry-run ${PYTHONCPPSOURCES} diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..01354076 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,55 @@ +cmake_minimum_required(VERSION 3.26) +project(mscclpp LANGUAGES CUDA CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +option(ENABLE_TRACE "Enable tracing" OFF) +option(USE_MPI_FOR_TESTS "Use MPI for tests" ON) +option(USE_NPKIT "Use NPKIT" ON) +option(ALLOW_GDRCOPY "Use GDRCopy, if available" OFF) + +if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set (CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}/install" CACHE PATH "default install path" FORCE) +endif() + +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) + +find_package(CUDAToolkit REQUIRED) +find_package(IBVerbs REQUIRED) +find_package(NUMA REQUIRED) +if(USE_MPI_FOR_TESTS) + find_package(MPI REQUIRED) +endif() +if(ALLOW_GDRCOPY) + find_package(GDRCopy) +endif() + +include_directories(${CUDAToolkit_INCLUDE_DIRS}) +include(CTest) +include(FetchContent) +FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/b796f7d44681514f58a683a3a71ff17c94edb0c1.zip) +option(INSTALL_GTEST OFF) +FetchContent_MakeAvailable(googletest) +include(GoogleTest) + +set(CLANG_FORMAT_SOURCE_DIRS include src tests) +include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake) + +add_library(mscclpp SHARED) +target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src/include) +set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver) +if(ENABLE_TRACE) + target_compile_definitions(mscclpp PRIVATE ENABLE_TRACE) +endif() +if(USE_NPKIT) + target_compile_definitions(mscclpp PRIVATE ENABLE_NPKIT) +endif() +if(ALLOW_GDRCOPY AND GDRCOPY_FOUND) + target_compile_definitions(mscclpp PRIVATE MSCCLPP_USE_GDRCOPY) + target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy) +endif() + +add_subdirectory(include) # This adds the public headers to install with mscclpp +add_subdirectory(src) # This adds the sources to the mscclpp target +add_subdirectory(test) diff --git a/Makefile b/Makefile deleted file mode 100644 index 02233e89..00000000 --- a/Makefile +++ /dev/null @@ -1,236 +0,0 @@ -######## VERSION -MSCCLPP_MAJOR := 0 -MSCCLPP_MINOR := 1 -MSCCLPP_PATCH := 0 - -######## COMPILE OPTIONS -DEBUG ?= 0 -VERBOSE ?= 1 -TRACE ?= 0 -NPKIT ?= 0 -GDRCOPY ?= 0 -USE_MPI_FOR_TESTS ?= 1 - -######## CUDA -CUDA_HOME ?= /usr/local/cuda -CUDA_LIB ?= $(CUDA_HOME)/lib64 -CUDA_INC ?= $(CUDA_HOME)/include -NVCC = $(CUDA_HOME)/bin/nvcc -CUDA_VERSION = $(strip $(shell which $(NVCC) >/dev/null && $(NVCC) --version | grep release | sed 's/.*release //' | sed 's/\,.*//')) -CUDA_MAJOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 1) -CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2) -# You should define NVCC_GENCODE in your environment to the minimal set -# of archs to reduce compile time. -CUDA8_GENCODE = -gencode=arch=compute_50,code=sm_50 \ - -gencode=arch=compute_60,code=sm_60 \ - -gencode=arch=compute_61,code=sm_61 -CUDA9_GENCODE = -gencode=arch=compute_70,code=sm_70 -CUDA11_GENCODE = -gencode=arch=compute_80,code=sm_80 -CUDA12_GENCODE = -gencode=arch=compute_90,code=sm_90 - -CUDA8_PTX = -gencode=arch=compute_61,code=compute_61 -CUDA9_PTX = -gencode=arch=compute_70,code=compute_70 -CUDA11_PTX = -gencode=arch=compute_80,code=compute_80 -CUDA12_PTX = -gencode=arch=compute_90,code=compute_90 - -######## CXX/NVCC -CXX := g++ -NVTX ?= 1 - -ifeq ($(shell test "0$(CUDA_MAJOR)" -eq 11 -a "0$(CUDA_MINOR)" -ge 8 -o "0$(CUDA_MAJOR)" -gt 11; echo $$?),0) -# Include Hopper support if we're using CUDA11.8 or above - NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA12_GENCODE) $(CUDA12_PTX) -else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 11; echo $$?),0) - NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA11_GENCODE) $(CUDA11_PTX) -# Include Volta support if we're using CUDA9 or above -else ifeq ($(shell test "0$(CUDA_MAJOR)" -ge 9; echo $$?),0) - NVCC_GENCODE ?= $(CUDA9_GENCODE) $(CUDA9_PTX) -else - NVCC_GENCODE ?= $(CUDA8_GENCODE) $(CUDA8_PTX) -endif -$(info NVCC_GENCODE is ${NVCC_GENCODE}) - -CXXFLAGS := -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR) -fPIC -fvisibility=hidden \ - -Wall -Wno-unused-function -Wno-sign-compare -std=c++14 -Wvla \ - -I $(CUDA_INC) \ - $(CXXFLAGS) - -ifneq ($(TRACE), 0) -CXXFLAGS += -DENABLE_TRACE -endif - -NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all -# Use addprefix so that we can specify more than one path -NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt - -ifeq ($(DEBUG), 0) -NVCUFLAGS += -O3 -CXXFLAGS += -O3 -g -else -NVCUFLAGS += -O0 -G -g -CXXFLAGS += -O0 -g -ggdb3 -endif - -ifneq ($(VERBOSE), 0) -NVCUFLAGS += -Xptxas -v -Xcompiler -Wall,-Wextra,-Wno-unused-parameter -CXXFLAGS += -Wall -Wextra -else -.SILENT: -endif - -ifeq ($(NVTX), 0) -CXXFLAGS += -DNVTX_DISABLE -endif - -#### MPI (only for test code) -ifeq ($(USE_MPI_FOR_TESTS), 1) -MPI_HOME ?= /usr/local/mpi -MPI_INC := -I$(MPI_HOME)/include -MPI_LDFLAGS := -L$(MPI_HOME)/lib -lmpi -MPI_MACRO := -D MSCCLPP_USE_MPI_FOR_TESTS -else -MPI_HOME := -MPI_INC := -MPI_LDFLAGS := -MPI_MACRO := -endif - -#### GDRCOPY -ifeq ($(GDRCOPY), 1) -GDRCOPY_LDFLAGS := -lgdrapi -CXXFLAGS += -DMSCCLPP_USE_GDRCOPY -NVCUFLAGS += -DMSCCLPP_USE_GDRCOPY -else -GDRCOPY_LDFLAGS := -endif - -#### MSCCL++ -BUILDDIR ?= $(abspath ./build) -INCDIR := include -LIBDIR := lib -OBJDIR := obj -BINDIR := bin - -ifneq ($(NPKIT), 0) -CXXFLAGS += -DENABLE_NPKIT -NVCUFLAGS += -DENABLE_NPKIT -endif - -LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma - -LIBSRCS := $(addprefix src/,debug.cc utils.cc param.cc init.cc proxy.cc ib.cc config.cc) -LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) -ifneq ($(NPKIT), 0) -LIBSRCS += $(addprefix src/misc/,npkit.cc) -endif -ifeq ($(GDRCOPY), 1) -LIBSRCS += $(addprefix src/,gdr.cc) -endif -LIBOBJS := $(patsubst %.cc,%.o,$(LIBSRCS)) -LIBOBJTARGETS := $(LIBOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) - -HEADERS := $(wildcard src/include/*.h) -CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*") -PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)') - -INCEXPORTS := mscclpp.h mscclppfifo.h -INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%) - -LIBNAME := libmscclpp.so -LIBSONAME := $(LIBNAME).$(MSCCLPP_MAJOR) -LIBTARGET := $(BUILDDIR)/$(LIBDIR)/$(LIBNAME).$(MSCCLPP_MAJOR).$(MSCCLPP_MINOR).$(MSCCLPP_PATCH) - -UTDIR := tests/unittests -UTSRCS := $(addprefix $(UTDIR)/,ib_test.cc) -UTOBJS := $(patsubst %.cc,%.o,$(UTSRCS)) -UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) -UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) - -TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu) -TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) -TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) -TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) - -MSCLLPPTESTSOBJSDIR:= $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR) -MSCLLPPTESTBINFILESLIST := allgather_test allreduce_test sendrecv_test -MSCLLPPTESTBINS := $(MSCLLPPTESTBINFILESLIST:%=$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf) - -INCLUDE := -Isrc -Isrc/include - -.PHONY: all build lib tests mscclpp-test clean - -all: build - -build: lib tests -ifeq ($(USE_MPI_FOR_TESTS), 1) -build: lib tests mscclpp-test -endif - -lib: $(LIBOBJTARGETS) $(INCTARGETS) $(LIBTARGET) - -unittests: $(UTBINS) - -tests: unittests $(TESTSBINS) - -mscclpp-test: $(LIBTARGET) $(MSCLLPPTESTBINS) - -cpplint: - clang-format-12 -style=file --verbose --Werror --dry-run $(CPPSOURCES) - clang-format-12 --dry-run $(PYTHONCPPSOURCES) - -cpplint-autofix: - clang-format-12 -style=file --verbose --Werror -i $(CPPSOURCES) - clang-format-12 -i $(PYTHONCPPSOURCES) - -# Run cpplint on a single file, example: make cpplint-file-autofix INPUTFILE=src/bootstrap/bootstrap.cc -cpplint-file-autofix: - clang-format-12 -style=file --verbose --Werror -i $(INPUTFILE) - -# Compile libobjs -$(BUILDDIR)/$(OBJDIR)/%.o: %.cc $(HEADERS) - @mkdir -p $(@D) - $(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $< - -# Compile utobjs -$(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o: $(UTDIR)/%.cc $(HEADERS) - @mkdir -p $(@D) - $(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $< - -$(BUILDDIR)/$(INCDIR)/%.h: src/$(INCDIR)/%.h - @mkdir -p $(@D) - cp $< $@ - -$(LIBTARGET): $(LIBOBJTARGETS) - @mkdir -p $(@D) - $(CXX) -shared -Wl,--no-as-needed -Wl,-soname,$(LIBSONAME) -o $@ $^ $(CXXFLAGS) $(LDFLAGS) - ln -sf $(LIBTARGET) $(BUILDDIR)/$(LIBDIR)/$(LIBNAME) - ln -sf $(LIBTARGET) $(BUILDDIR)/$(LIBDIR)/$(LIBSONAME) - -# UT bins -$(BUILDDIR)/$(BINDIR)/$(UTDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o $(LIBOBJTARGETS) - @mkdir -p $(@D) - $(NVCC) -o $@ $+ $(MPI_LDFLAGS) $(LDFLAGS) - -# Compile .cc tests -$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cc $(INCTARGETS) - @mkdir -p $(@D) - $(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -c $< $(MPI_MACRO) - -# Compile .cu tests -$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cu $(INCTARGETS) - @mkdir -p $(@D) - $(NVCC) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(NVCUFLAGS) $(INCLUDE) -c $< $(MPI_MACRO) - -# Test bins -$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o $(LIBTARGET) - @mkdir -p $(@D) - $(NVCC) -o $@ $< $(MPI_LDFLAGS) -L$(BUILDDIR)/$(LIBDIR) -lmscclpp - -# Compile mscclpp_test -$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf: $(MSCLLPPTESTSOBJSDIR)/%.o $(MSCLLPPTESTSOBJSDIR)/common.o - @mkdir -p $(@D) - $(NVCC) -o $@ $^ $(MPI_LDFLAGS) -L$(BUILDDIR)/$(LIBDIR) -lmscclpp - -clean: - rm -rf $(BUILDDIR) diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..677b46cf --- /dev/null +++ b/TODO.md @@ -0,0 +1,8 @@ +# Core API extraction + +- Add a test for host side Communicator/RegisteredMemory/Connection use. +- Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this. +- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remants of the old code is in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. Probably need a manager class to wrap all of this. +- Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes. +- Implement IbQp::~IbQp() +- Fix RegisteredMemory::Impl::Impl to get the IPC handle from the base pointer, not the derived pointer. \ No newline at end of file diff --git a/cmake/AddClangFormatTargets.cmake b/cmake/AddClangFormatTargets.cmake new file mode 100644 index 00000000..49e142a3 --- /dev/null +++ b/cmake/AddClangFormatTargets.cmake @@ -0,0 +1,18 @@ +# Add targets to run clang-format + +find_program(CLANG_FORMAT clang-format) +if(CLANG_FORMAT) + message(STATUS "Found clang-format: ${CLANG_FORMAT}") + set(CLANG_FORMAT_FILE_TYPES *.h *.hpp *.c *.cc *.cpp *.cu) + # Produce combinations of source directories and file types + foreach(SOURCE_DIR ${CLANG_FORMAT_SOURCE_DIRS}) + foreach(FILE_TYPE ${CLANG_FORMAT_FILE_TYPES}) + list(APPEND CLANG_FORMAT_SOURCE_PATTERNS ${SOURCE_DIR}/${FILE_TYPE}) + endforeach() + endforeach() + file(GLOB_RECURSE CLANG_FORMAT_SOURCES ${CLANG_FORMAT_SOURCE_PATTERNS}) + add_custom_target(check-format ALL COMMAND ${CLANG_FORMAT} -style=file --dry-run ${CLANG_FORMAT_SOURCES}) + add_custom_target(format COMMAND ${CLANG_FORMAT} -style=file -i ${CLANG_FORMAT_SOURCES}) +else() + message(STATUS "clang-format not found.") +endif() diff --git a/cmake/FindGDRCopy.cmake b/cmake/FindGDRCopy.cmake new file mode 100644 index 00000000..cde447ba --- /dev/null +++ b/cmake/FindGDRCopy.cmake @@ -0,0 +1,41 @@ +# Find the GDRCopy libraries +# +# The following variables are optionally searched for defaults +# GDRCOPY_ROOT_DIR: Base directory where all GDRCopy components are found +# GDRCOPY_INCLUDE_DIR: Directory where GDRCopy headers are found +# GDRCOPY_LIB_DIR: Directory where GDRCopy libraries are found + +# The following are set after configuration is done: +# GDRCOPY_FOUND +# GDRCOPY_INCLUDE_DIRS +# GDRCOPY_LIBRARIES + +# An imported target MSCCLPP::gdrcopy is created if the library is found. + +find_path(GDRCOPY_INCLUDE_DIRS + NAMES gdrapi.h + HINTS + ${GDRCOPY_INCLUDE_DIR} + ${GDRCOPY_ROOT_DIR} + ${GDRCOPY_ROOT_DIR}/include) + +find_library(GDRCOPY_LIBRARIES + NAMES gdrapi + HINTS + ${GDRCOPY_LIB_DIR} + ${GDRCOPY_ROOT_DIR} + ${GDRCOPY_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(GDRCopy DEFAULT_MSG GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES) +mark_as_advanced(GDRCOPY_INCLUDE_DIR GDRCOPY_LIBRARIES) + +if(GDRCOPY_FOUND) + if(NOT TARGET MSCCLPP::gdrcopy) + add_library(MSCCLPP::gdrcopy UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::gdrcopy PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GDRCOPY_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${GDRCOPY_LIBRARIES}") +endif() \ No newline at end of file diff --git a/cmake/FindIBVerbs.cmake b/cmake/FindIBVerbs.cmake new file mode 100644 index 00000000..fc80b11c --- /dev/null +++ b/cmake/FindIBVerbs.cmake @@ -0,0 +1,41 @@ +# Find the IB Verbs libraries +# +# The following variables are optionally searched for defaults +# IBVERBS_ROOT_DIR: Base directory where all ibverbs components are found +# IBVERBS_INCLUDE_DIR: Directory where ibverbs headers are found +# IBVERBS_LIB_DIR: Directory where ibverbs libraries are found + +# The following are set after configuration is done: +# IBVERBS_FOUND +# IBVERBS_INCLUDE_DIRS +# IBVERBS_LIBRARIES + +# An imported target MSCCLPP::ibverbs is created if the library is found. + +find_path(IBVERBS_INCLUDE_DIRS + NAMES infiniband/verbs.h + HINTS + ${IBVERBS_INCLUDE_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/include) + +find_library(IBVERBS_LIBRARIES + NAMES ibverbs + HINTS + ${IBVERBS_LIB_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(IBVerbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES) +mark_as_advanced(IBVERBS_INCLUDE_DIR IBVERBS_LIBRARIES) + +if(IBVERBS_FOUND) + if(NOT TARGET MSCCLPP::ibverbs) + add_library(MSCCLPP::ibverbs UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::ibverbs PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${IBVERBS_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${IBVERBS_LIBRARIES}") +endif() \ No newline at end of file diff --git a/cmake/FindNUMA.cmake b/cmake/FindNUMA.cmake new file mode 100644 index 00000000..70e04d53 --- /dev/null +++ b/cmake/FindNUMA.cmake @@ -0,0 +1,41 @@ +# Find the numa libraries +# +# The following variables are optionally searched for defaults +# NUMA_ROOT_DIR: Base directory where all numa components are found +# NUMA_INCLUDE_DIR: Directory where numa headers are found +# NUMA_LIB_DIR: Directory where numa libraries are found + +# The following are set after configuration is done: +# NUMA_FOUND +# NUMA_INCLUDE_DIRS +# NUMA_LIBRARIES + +# An imported target MSCCLPP::numa is created if the library is found. + +find_path(NUMA_INCLUDE_DIRS + NAMES numa.h + HINTS + ${NUMA_INCLUDE_DIR} + ${NUMA_ROOT_DIR} + ${NUMA_ROOT_DIR}/include) + +find_library(NUMA_LIBRARIES + NAMES numa + HINTS + ${NUMA_LIB_DIR} + ${NUMA_ROOT_DIR} + ${NUMA_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_INCLUDE_DIRS NUMA_LIBRARIES) +mark_as_advanced(NUMA_INCLUDE_DIR NUMA_LIBRARIES) + +if(NUMA_FOUND) + if(NOT TARGET MSCCLPP::numa) + add_library(MSCCLPP::numa UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::numa PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${NUMA_LIBRARIES}") +endif() \ No newline at end of file diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt new file mode 100644 index 00000000..b5fa7984 --- /dev/null +++ b/include/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE HEADERS CONFIGURE_DEPENDS *.hpp) +target_sources(mscclpp PUBLIC FILE_SET HEADERS FILES ${HEADERS}) +install(TARGETS mscclpp FILE_SET HEADERS) diff --git a/include/mscclpp/channel.hpp b/include/mscclpp/channel.hpp new file mode 100644 index 00000000..474244ce --- /dev/null +++ b/include/mscclpp/channel.hpp @@ -0,0 +1,259 @@ +#ifndef MSCCLPP_CHANNEL_HPP_ +#define MSCCLPP_CHANNEL_HPP_ + +#include +#include +#include +#include + +namespace mscclpp { +namespace channel { + +// A Channel pairs a Connection with an Epoch +class Channel { + public: + Channel(Communicator& communicator, std::shared_ptr connection) + : connection_(connection), epoch_(std::make_shared(communicator, connection)){}; + + Connection& connection() { return *connection_; } + DeviceEpoch& epoch() { return *epoch_; } + + private: + std::shared_ptr connection_; + std::shared_ptr epoch_; +}; + +using ChannelId = uint32_t; + +using TriggerType = uint64_t; +const TriggerType TriggerData = 0x1; +const TriggerType TriggerFlag = 0x2; +const TriggerType TriggerSync = 0x4; + +// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles +// mapping to the actual +using MemoryId = uint32_t; + +#define MSCCLPP_BITS_SIZE 32 +#define MSCCLPP_BITS_OFFSET 32 +#define MSCCLPP_BITS_REGMEM_HANDLE 8 +#define MSCCLPP_BITS_TYPE 3 +#define MSCCLPP_BITS_CONNID 10 + +// this is the basic structure of each work element in the fifo +// the summation of number of bits must be 128 or less +union ChannelTrigger { + ProxyTrigger value; + struct { + // first 64 bits: value[0] + uint64_t size : MSCCLPP_BITS_SIZE; + uint64_t srcOffset : MSCCLPP_BITS_OFFSET; + uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + // second 64 bits: value[1] + uint64_t dstOffset : MSCCLPP_BITS_OFFSET; + uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t chanId : MSCCLPP_BITS_CONNID; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - + MSCCLPP_BITS_TYPE); // ensure 64-bit alignment + } fields; + +#ifdef __CUDACC__ + __device__ ChannelTrigger() {} + __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} + __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size, int connectionId) { + value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); + value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) + << MSCCLPP_BITS_REGMEM_HANDLE) + + src) + << MSCCLPP_BITS_OFFSET) + + dstOffset); + } +#endif // __CUDACC__ +}; + +struct DeviceChannel { + DeviceChannel() = default; + + DeviceChannel(ChannelId channelId, DeviceEpoch::DeviceHandle epoch, DeviceProxyFifo fifo) + : channelId_(channelId), epoch_(epoch), fifo_(fifo) {} + + DeviceChannel(const DeviceChannel& other) = default; + + DeviceChannel& operator=(DeviceChannel& other) = default; + +#ifdef __CUDACC__ + __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size) { + fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value); + } + + __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + put(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void signal() { + epochIncrement(); + fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value); + } + + __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size) { + epochIncrement(); + fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value); + } + + __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + putWithSignal(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, + uint64_t srcOffset, uint64_t size) { + epochIncrement(); + uint64_t curFifoHead = fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_) + .value); + while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead) + ; + } + + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + putWithSignalAndFlush(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void flush() { + uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, channelId_).value); + // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail + // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. + while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead) + ; + } + + __forceinline__ __device__ void wait() { epoch_.wait(); } + + __forceinline__ __device__ void epochIncrement() { epoch_.epochIncrement(); } +#endif // __CUDACC__ + + ChannelId channelId_; + + DeviceEpoch::DeviceHandle epoch_; + + // this is a concurrent fifo which is multiple threads from the device + // can produce for and the sole proxy thread consumes it. + DeviceProxyFifo fifo_; +}; + +class DeviceChannelService; + +inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService); + +class DeviceChannelService { + public: + DeviceChannelService(Communicator& communicator); + + ChannelId addChannel(std::shared_ptr connection) { + channels_.push_back(Channel(communicator_, connection)); + return channels_.size() - 1; + } + + MemoryId addMemory(RegisteredMemory memory) { + memories_.push_back(memory); + return memories_.size() - 1; + } + + Channel channel(ChannelId id) { return channels_[id]; } + DeviceChannel deviceChannel(ChannelId id) { + return DeviceChannel(id, channels_[id].epoch().deviceHandle(), proxy_.fifo().deviceFifo()); + } + + void startProxy() { proxy_.start(); } + void stopProxy() { proxy_.stop(); } + + private: + Communicator& communicator_; + std::vector channels_; + std::vector memories_; + Proxy proxy_; + int deviceNumaNode; + + void bindThread(); + + ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) { + ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); + Channel& channel = channels_[trigger->fields.chanId]; + + auto result = ProxyHandlerResult::Continue; + + if (trigger->fields.type & TriggerData) { + RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; + RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; + channel.connection().write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size); + } + + if (trigger->fields.type & TriggerFlag) { + channel.epoch().signal(); + } + + if (trigger->fields.type & TriggerSync) { + channel.connection().flush(); + result = ProxyHandlerResult::FlushFifoTailAndContinue; + } + + return result; + } +}; + +struct SimpleDeviceChannel { + SimpleDeviceChannel() = default; + + SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {} + + SimpleDeviceChannel(const SimpleDeviceChannel& other) = default; + + SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default; + +#ifdef __CUDACC__ + + __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + devChan_.put(dst_, dstOffset, src_, srcOffset, size); + } + + __forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); } + + __forceinline__ __device__ void signal() { devChan_.signal(); } + + __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + devChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) { + putWithSignalAndFlush(offset, offset, size); + } + + __forceinline__ __device__ void flush() { devChan_.flush(); } + + __forceinline__ __device__ void wait() { devChan_.wait(); } + + __forceinline__ __device__ void epochIncrement() { devChan_.epochIncrement(); } + +#endif // __CUDACC__ + + DeviceChannel devChan_; + MemoryId dst_; + MemoryId src_; +}; + +} // namespace channel +} // namespace mscclpp + +#endif // MSCCLPP_CHANNEL_HPP_ diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp new file mode 100644 index 00000000..b6249bfd --- /dev/null +++ b/include/mscclpp/core.hpp @@ -0,0 +1,296 @@ +#ifndef MSCCLPP_CORE_HPP_ +#define MSCCLPP_CORE_HPP_ + +#define MSCCLPP_MAJOR 0 +#define MSCCLPP_MINOR 1 +#define MSCCLPP_PATCH 0 +#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH) + +#include +#include +#include +#include +#include +#include + +namespace mscclpp { + +#define MSCCLPP_UNIQUE_ID_BYTES 128 +struct UniqueId { + char internal[MSCCLPP_UNIQUE_ID_BYTES]; +}; + +class BaseBootstrap { + public: + BaseBootstrap(){}; + virtual ~BaseBootstrap() = default; + virtual int getRank() = 0; + virtual int getNranks() = 0; + virtual void send(void* data, int size, int peer, int tag) = 0; + virtual void recv(void* data, int size, int peer, int tag) = 0; + virtual void allGather(void* allData, int size) = 0; + virtual void barrier() = 0; + + // TODO: move implementations of these helpers out of this header + void send(const std::vector& data, int peer, int tag) { + size_t size = data.size(); + send((void*)&size, sizeof(size_t), peer, tag); + send((void*)data.data(), data.size(), peer, tag + 1); + } + void recv(std::vector& data, int peer, int tag) { + size_t size; + recv((void*)&size, sizeof(size_t), peer, tag); + data.resize(size); + recv((void*)data.data(), data.size(), peer, tag + 1); + } +}; + +class Bootstrap : public BaseBootstrap { + public: + Bootstrap(int rank, int nRanks); + ~Bootstrap(); + + UniqueId createUniqueId(); + UniqueId getUniqueId() const; + + void initialize(UniqueId uniqueId); + void initialize(std::string ipPortPair); + int getRank() override; + int getNranks() override; + void send(void* data, int size, int peer, int tag) override; + void recv(void* data, int size, int peer, int tag) override; + void allGather(void* allData, int size) override; + void barrier() override; + + private: + class Impl; + std::unique_ptr pimpl_; +}; + +/* Create a unique ID for communication. Only needs to be called by one process. + * Use with mscclppCommInitRankFromId(). + * All processes need to provide the same ID to mscclppCommInitRankFromId(). + * + * Outputs: + * uniqueId: the unique ID to be created + */ +std::unique_ptr getUniqueId(); + +enum class Transport { Unknown, CudaIpc, IB0, IB1, IB2, IB3, IB4, IB5, IB6, IB7, NumTransports }; + +namespace detail { +const size_t TransportFlagsSize = 10; +static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), + "TransportFlagsSize must match the number of transports"); +using TransportFlagsBase = std::bitset; +} // namespace detail + +class TransportFlags : private detail::TransportFlagsBase { + public: + TransportFlags() = default; + TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast(transport)) {} + + bool has(Transport transport) const { return detail::TransportFlagsBase::test(static_cast(transport)); } + + bool none() const { return detail::TransportFlagsBase::none(); } + + bool any() const { return detail::TransportFlagsBase::any(); } + + bool all() const { return detail::TransportFlagsBase::all(); } + + size_t count() const { return detail::TransportFlagsBase::count(); } + + TransportFlags& operator|=(TransportFlags other) { + detail::TransportFlagsBase::operator|=(other); + return *this; + } + + TransportFlags operator|(TransportFlags other) const { return TransportFlags(*this) |= other; } + + TransportFlags operator|(Transport transport) const { return *this | TransportFlags(transport); } + + TransportFlags& operator&=(TransportFlags other) { + detail::TransportFlagsBase::operator&=(other); + return *this; + } + + TransportFlags operator&(TransportFlags other) const { return TransportFlags(*this) &= other; } + + TransportFlags operator&(Transport transport) const { return *this & TransportFlags(transport); } + + TransportFlags& operator^=(TransportFlags other) { + detail::TransportFlagsBase::operator^=(other); + return *this; + } + + TransportFlags operator^(TransportFlags other) const { return TransportFlags(*this) ^= other; } + + TransportFlags operator^(Transport transport) const { return *this ^ TransportFlags(transport); } + + TransportFlags operator~() const { return TransportFlags(*this).flip(); } + + bool operator==(TransportFlags other) const { return detail::TransportFlagsBase::operator==(other); } + + bool operator!=(TransportFlags other) const { return detail::TransportFlagsBase::operator!=(other); } + + detail::TransportFlagsBase toBitset() const { return *this; } + + private: + TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {} +}; + +inline TransportFlags operator|(Transport transport1, Transport transport2) { + return TransportFlags(transport1) | transport2; +} + +inline TransportFlags operator&(Transport transport1, Transport transport2) { + return TransportFlags(transport1) & transport2; +} + +inline TransportFlags operator^(Transport transport1, Transport transport2) { + return TransportFlags(transport1) ^ transport2; +} + +const TransportFlags NoTransports = TransportFlags(); +const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | + Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; +const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc; + +int getIBDeviceCount(); +std::string getIBDeviceName(Transport ibTransport); +Transport getIBTransportByDeviceName(const std::string& ibDeviceName); + +class Communicator; +class Connection; + +class RegisteredMemory { + struct Impl; + // A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated + // lazily. + std::shared_ptr pimpl; + + public: + RegisteredMemory() = default; + RegisteredMemory(std::shared_ptr pimpl); + ~RegisteredMemory(); + + void* data(); + size_t size(); + int rank(); + TransportFlags transports(); + + std::vector serialize(); + static RegisteredMemory deserialize(const std::vector& data); + + friend class Connection; + friend class Communicator; +}; + +class Connection { + public: + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) = 0; + + virtual void flush() = 0; + + virtual int remoteRank() = 0; + + virtual int tag() = 0; + + virtual Transport transport() = 0; + + virtual Transport remoteTransport() = 0; + + protected: + static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); +}; + +struct Setuppable { + virtual void beginSetup(std::shared_ptr) {} + virtual void endSetup(std::shared_ptr) {} +}; + +template +class NonblockingFuture { + std::shared_future future; + + public: + NonblockingFuture() = default; + NonblockingFuture(std::shared_future&& future) : future(std::move(future)) {} + NonblockingFuture(const NonblockingFuture&) = default; + + bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; } + + T get() { + if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage); + return future.get(); + } +}; + +class Communicator { + public: + /* Initialize the communicator. + * + * Inputs: + * bootstrap: an implementation of the of BaseBootstrap that the communicator will use + */ + Communicator(std::shared_ptr bootstrap); + + ~Communicator(); + + /* Return the bootstrapper held by this communicator. */ + std::shared_ptr bootstrapper(); + + /* Register a region of GPU memory for use in this communicator. + * + * Inputs: + * data: base pointer to the memory + * size: size of the memory region in bytes + * + * Returns: a handle to the buffer + */ + RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); + + void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag); + + NonblockingFuture recvMemoryOnSetup(int remoteRank, int tag); + + /* Connect to a remote rank. This function only prepares metadata for connection. The actual connection + * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection + * from rank i to remote rank j needs to have a counterpart from rank j to rank i. + * Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages + * and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has + * security risks if the devConn's accesses are given to a malicious process. + * + * Inputs: + * remoteRank: the rank of the remote process + * tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be + * used to identify the connection inside a GPU kernel. + * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) + * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. + */ + std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport); + + /* Add a custom Setuppable object to a list of objects to be setup later, when setup() is called. */ + void onSetup(std::shared_ptr setuppable); + + /* Setup all objects that have registered for setup. This includes any connections created by connect(). */ + void setup(); + + struct Impl; + + private: + std::unique_ptr pimpl; +}; +} // namespace mscclpp + +namespace std { +template <> +struct hash { + size_t operator()(const mscclpp::TransportFlags& flags) const { + return hash()(flags.toBitset()); + } +}; +} // namespace std + +#endif // MSCCLPP_CORE_HPP_ diff --git a/include/mscclpp/epoch.hpp b/include/mscclpp/epoch.hpp new file mode 100644 index 00000000..539ad03f --- /dev/null +++ b/include/mscclpp/epoch.hpp @@ -0,0 +1,67 @@ +#ifndef MSCCLPP_EPOCH_HPP_ +#define MSCCLPP_EPOCH_HPP_ + +#include + +namespace mscclpp { + +struct alignas(16) EpochIds { + uint64_t outbound; + uint64_t inboundReplica; +}; + +class BaseEpoch { + private: + std::shared_ptr connection_; + RegisteredMemory localEpochIdsRegMem_; + NonblockingFuture remoteEpochIdsRegMem_; + + protected: + EpochIds* epochIds_; + uint64_t* expectedInboundEpochId_; + + public: + BaseEpoch(std::shared_ptr connection); + void setup(Communicator& communicator); + BaseEpoch(const BaseEpoch&) = delete; + void signal(); +}; + +class DeviceEpoch : BaseEpoch { + public: + DeviceEpoch(Communicator& communicator, std::shared_ptr connection); + DeviceEpoch(const DeviceEpoch&) = delete; + ~DeviceEpoch(); + void signal(); + + struct DeviceHandle { +#ifdef __CUDACC__ + __forceinline__ __device__ void wait() { + (*expectedInboundEpochId) += 1; + while (*(volatile uint64_t*)&(epochIds->inboundReplica) < (*expectedInboundEpochId)) + ; + } + + __forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(epochIds->outbound) += 1; } +#endif // __CUDACC__ + + EpochIds* epochIds; + uint64_t* expectedInboundEpochId; + }; + + DeviceHandle deviceHandle(); +}; + +class HostEpoch : BaseEpoch { + public: + HostEpoch(Communicator& communicator, std::shared_ptr connection); + HostEpoch(const HostEpoch&) = delete; + ~HostEpoch(); + + void increamentAndSignal(); + void wait(); +}; + +} // namespace mscclpp + +#endif // MSCCLPP_EPOCH_HPP_ diff --git a/include/mscclpp/errors.hpp b/include/mscclpp/errors.hpp new file mode 100644 index 00000000..cdaf9ed8 --- /dev/null +++ b/include/mscclpp/errors.hpp @@ -0,0 +1,60 @@ +#ifndef MSCCLPP_ERRORS_HPP_ +#define MSCCLPP_ERRORS_HPP_ + +#include +#include + +#include + +namespace mscclpp { + +enum class ErrorCode { + SystemError, + InternalError, + InvalidUsage, +}; + +std::string errorToString(enum ErrorCode error); + +class BaseError : public std::runtime_error { + public: + BaseError(std::string message, int errorCode); + explicit BaseError(int errorCode); + virtual ~BaseError() = default; + int getErrorCode() const; + const char* what() const noexcept override; + + private: + int errorCode_; + + protected: + std::string message_; +}; + +class Error : public BaseError { + public: + Error(std::string message, ErrorCode errorCode); + virtual ~Error() = default; +}; + +class CudaError : public BaseError { + public: + CudaError(std::string message, cudaError_t errorCode); + virtual ~CudaError() = default; +}; + +class CuError : public BaseError { + public: + CuError(std::string message, CUresult errorCode); + virtual ~CuError() = default; +}; + +class IbError : public BaseError { + public: + IbError(std::string message, int errorCode); + virtual ~IbError() = default; +}; + +}; // namespace mscclpp + +#endif // MSCCLPP_ERRORS_HPP_ diff --git a/include/mscclpp/fifo.hpp b/include/mscclpp/fifo.hpp new file mode 100644 index 00000000..aff86f8f --- /dev/null +++ b/include/mscclpp/fifo.hpp @@ -0,0 +1,74 @@ +#ifndef MSCCLPP_FIFO_HPP_ +#define MSCCLPP_FIFO_HPP_ + +#include + +#include +#include + +namespace mscclpp { + +// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered. +// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. +#define MSCCLPP_PROXY_FIFO_SIZE 128 +#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 + +struct alignas(16) ProxyTrigger { + uint64_t fst, snd; +}; + +/* This is a concurrent fifo where multiple device threads can push mscclppTrigger work elements to + * and a single host proxy thread consumes these work elements. There is a head pointer allocated on device + * which starts with 0 and goes to 2^64-1 which is almost infinity. There are two copies of tail, one + * that is on the deivce (tailReplica) and another that is on host (proxyState->fifoTailHost). + * The host always has the "true" tail and occasionally, pushes it to the copy on the device. + * Therefore, most of the time, the device has a stale version. The invariants are: + * tailReplica <= proxyState->fifoTailHost <= head. + * push() function increments head, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService + * and it occasionally flushes it to tailReplica via a cudaMemcpyAsync. + * + * Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates + * for the tail as there is usually enough space for device threads to push their work into. + */ +struct DeviceProxyFifo { +#ifdef __CUDACC__ + __forceinline__ __device__ uint64_t push(ProxyTrigger trigger) { + uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1); + while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica)) + ; + while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0) + ; + ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]); + asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); + return curFifoHead; + } +#endif // __CUDACC__ + + ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements + uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused + // occasionally to device + uint64_t* head; // Allocated on device. Only accessed by device +}; + +class HostProxyFifo { + public: + HostProxyFifo(); + + ~HostProxyFifo(); + + void poll(ProxyTrigger* trigger); + + void pop(); + + void flushTail(bool sync = false); + + DeviceProxyFifo deviceFifo(); + + private: + struct Impl; + std::unique_ptr pimpl; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_FIFO_HPP_ diff --git a/include/mscclpp/proxy.hpp b/include/mscclpp/proxy.hpp new file mode 100644 index 00000000..4e89e56b --- /dev/null +++ b/include/mscclpp/proxy.hpp @@ -0,0 +1,37 @@ +#ifndef MSCCLPP_PROXY_HPP_ +#define MSCCLPP_PROXY_HPP_ + +#include +#include +#include + +namespace mscclpp { + +enum class ProxyHandlerResult { + Continue, + FlushFifoTailAndContinue, + Stop, +}; + +class Proxy; +using ProxyHandler = std::function; + +class Proxy { + public: + Proxy(ProxyHandler handler, std::function threadInit); + Proxy(ProxyHandler handler); + ~Proxy(); + + void start(); + void stop(); + + HostProxyFifo& fifo(); + + private: + struct Impl; + std::unique_ptr pimpl; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_PROXY_HPP_ \ No newline at end of file diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index e627dcca..3171122d 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -127,6 +127,21 @@ struct _Comm { } }; +struct _P2PHandle { + struct mscclppRegisteredMemoryP2P _rmP2P; + struct mscclppIbMr _ibmr; + + _P2PHandle() : _rmP2P({0}), _ibmr({0}) {} + + _P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) { + _rmP2P = p2p; + if (_rmP2P.IbMr != nullptr) { + _ibmr = *_rmP2P.IbMr; + _rmP2P.IbMr = &_ibmr; + } + } +}; + nb::callable _log_callback; void _LogHandler(const char* msg) { @@ -141,6 +156,8 @@ static const std::string DOC_MscclppUniqueId = static const std::string DOC__Comm = "MSCCLPP Communications Handle"; +static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle"; + NB_MODULE(_py_mscclpp, m) { m.doc() = "Python bindings for MSCCLPP: which is not NCCL"; @@ -191,6 +208,9 @@ NB_MODULE(_py_mscclpp, m) { return nb::bytes(id.internal, sizeof(id.internal)); }); + nb::class_<_P2PHandle>(m, "_P2PHandle") + .def_ro_static("__doc__", &DOC__P2PHandle); + nb::class_<_Comm>(m, "_Comm") .def_ro_static("__doc__", &DOC__Comm) .def_static( @@ -239,6 +259,31 @@ NB_MODULE(_py_mscclpp, m) { "Is this comm object closed?") .def_ro("rank", &_Comm::_rank) .def_ro("world_size", &_Comm::_world_size) + .def( + "register_buffer", + [](_Comm& comm, + uint64_t local_buff, + uint64_t buff_size) -> std::vector<_P2PHandle> { + comm.check_open(); + mscclppRegisteredMemory regMem; + checkResult( + mscclppRegisterBuffer( + comm._handle, + reinterpret_cast(local_buff), + buff_size, + ®Mem), + "Registering buffer failed"); + + std::vector<_P2PHandle> handles; + for (const auto& p2p : regMem.p2p) { + handles.push_back(_P2PHandle(p2p)); + } + return handles; + }, + "local_buf"_a, + "buff_size"_a, + nb::call_guard(), + "Register a buffer for P2P transfers.") .def( "connect", [](_Comm& comm, @@ -247,9 +292,7 @@ NB_MODULE(_py_mscclpp, m) { uint64_t local_buff, uint64_t buff_size, mscclppTransport_t transport_type) -> void { - if (comm._proxies_running) { - throw std::invalid_argument("Proxy Threads Already Running"); - } + comm.check_open(); RETRY( mscclppConnect( comm._handle, @@ -308,15 +351,6 @@ NB_MODULE(_py_mscclpp, m) { "Start the MSCCLPP proxy.") .def("close", &_Comm::close, nb::call_guard()) .def("__del__", &_Comm::close, nb::call_guard()) - .def( - "connection_setup", - [](_Comm& comm) -> void { - comm.check_open(); - checkResult( - mscclppConnectionSetup(comm._handle), - "Connection Setup Failed"); - }, - nb::call_guard()) .def( "bootstrap_all_gather_int", [](_Comm& comm, int val) -> std::vector { diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index cbb84c2c..51c564d8 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -18,6 +18,7 @@ __all__ = ( ) _Comm = _py_mscclpp._Comm +_P2PHandle = _py_mscclpp._P2PHandle TransportType = _py_mscclpp.TransportType MscclppUniqueId = _py_mscclpp.MscclppUniqueId @@ -46,6 +47,7 @@ MSCCLPP_LOG_LEVELS: set[str] = { "TRACE", } + def _setup_logging(level: str = "INFO"): """Setup log hooks for the C library.""" level = level.upper() @@ -177,3 +179,25 @@ class Comm: def stop_proxies(self) -> None: self._comm.stop_proxies() + + def register_buffer( + self, + data_ptr, + data_size: int, + ) -> list[_P2PHandle]: + return [ + P2PHandle(self, h) + for h in self._comm.register_buffer( + data_ptr, + data_size, + ) + ] + + +class P2PHandle: + _comm: Comm + _handle: _P2PHandle + + def __init__(self, comm: Comm, handle: _P2PHandle): + self._comm = comm + self._handle = handle diff --git a/python/src/mscclpp/test_mscclpp.py b/python/src/mscclpp/test_mscclpp.py index 6b162f7d..d4159ad5 100644 --- a/python/src/mscclpp/test_mscclpp.py +++ b/python/src/mscclpp/test_mscclpp.py @@ -79,6 +79,8 @@ class CommsTest(unittest.TestCase): if errors: parts = [] for rank, content in errors: - parts.append(f"[rank {rank}]: " + content.decode('utf-8', errors='ignore')) + parts.append( + f"[rank {rank}]: " + content.decode("utf-8", errors="ignore") + ) raise AssertionError("\n\n".join(parts)) diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index 6f5c5ec7..0d5d8aa1 100644 --- a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -73,7 +73,7 @@ def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp. comm.connection_setup() -def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm): +def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm): rank = options.rank buf = torch.zeros([options.world_size], dtype=torch.int64) @@ -95,6 +95,12 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm): mscclpp.TransportType.P2P, ) + handles = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel()) + hamcrest.assert_that( + handles, + hamcrest.has_length(options.world_size - 1), + ) + torch.cuda.synchronize() comm.connection_setup() @@ -103,7 +109,6 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm): comm.stop_proxies() - def main(): p = argparse.ArgumentParser() p.add_argument("--rank", type=int, required=True) @@ -131,7 +136,7 @@ def main(): _test_bootstrap_allgather_bytes(options, comm) _test_bootstrap_allgather_json(options, comm) _test_bootstrap_allgather_pickle(options, comm) - _test_p2p_connect(options, comm) + _test_rm(options, comm) finally: comm.close() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 00000000..dc86f638 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc) +target_sources(mscclpp PRIVATE ${SOURCES}) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 064af4a8..7efe46ae 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -1,567 +1,439 @@ -/************************************************************************* - * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "bootstrap.h" -#include "config.h" -#include "core.h" -#include "mscclpp.h" -#include "utils.h" -#include -#include - -struct bootstrapRootArgs -{ - struct mscclppSocket* listenSock; - uint64_t magic; -}; - -/* Init functions */ -static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1]; -static union mscclppSocketAddress bootstrapNetIfAddr; -static int bootstrapNetInitDone = 0; -pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; - -mscclppResult_t bootstrapNetInit(const char* ip_port_pair) -{ - if (bootstrapNetInitDone == 0) { - pthread_mutex_lock(&bootstrapNetLock); - if (bootstrapNetInitDone == 0) { - const char* env; - if (ip_port_pair) { - env = ip_port_pair; - } else { - env = getenv("MSCCLPP_COMM_ID"); - } - if (env) { - union mscclppSocketAddress remoteAddr; - if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) { - WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); - return mscclppInvalidArgument; - } - if (mscclppFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, - 1) <= 0) { - WARN("NET/Socket : No usable listening interface found"); - return mscclppSystemError; - } - } else { - int nIfs = mscclppFindInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1); - if (nIfs <= 0) { - WARN("Bootstrap : no socket interface found"); - return mscclppInternalError; - } - } - char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; - sprintf(line, " %s:", bootstrapNetIfName); - mscclppSocketToString(&bootstrapNetIfAddr, line + strlen(line)); - INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line); - bootstrapNetInitDone = 1; - } - pthread_mutex_unlock(&bootstrapNetLock); - } - return mscclppSuccess; -} - -/* Socket Interface Selection type */ -enum bootstrapInterface_t -{ - findSubnetIf = -1, - dontCareIf = -2 -}; - -// Additional sync functions -static mscclppResult_t bootstrapNetSend(struct mscclppSocket* sock, void* data, int size) -{ - MSCCLPPCHECK(mscclppSocketSend(sock, &size, sizeof(int))); - MSCCLPPCHECK(mscclppSocketSend(sock, data, size)); - return mscclppSuccess; -} -static mscclppResult_t bootstrapNetRecv(struct mscclppSocket* sock, void* data, int size) -{ - int recvSize; - MSCCLPPCHECK(mscclppSocketRecv(sock, &recvSize, sizeof(int))); - if (recvSize > size) { - WARN("Message truncated : received %d bytes instead of %d", recvSize, size); - return mscclppInternalError; - } - MSCCLPPCHECK(mscclppSocketRecv(sock, data, std::min(recvSize, size))); - return mscclppSuccess; -} - -struct extInfo -{ - int rank; - int nranks; - union mscclppSocketAddress extAddressListenRoot; - union mscclppSocketAddress extAddressListen; -}; - #include +#include -static mscclppResult_t setFilesLimit() -{ - struct rlimit filesLimit; +#include +#include +#include +#include +#include +#include +#include +#include + +#include "api.h" +#include "checks.hpp" +#include "socket.h" +#include "utils.h" + +using namespace mscclpp; + +namespace { + +mscclppResult_t setFilesLimit() { + rlimit filesLimit; SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit"); filesLimit.rlim_cur = filesLimit.rlim_max; SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit"); return mscclppSuccess; } -static void* bootstrapRoot(void* rargs) -{ - struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs; - struct mscclppSocket* listenSock = args->listenSock; - uint64_t magic = args->magic; - mscclppResult_t res = mscclppSuccess; - int nranks = 0, c = 0; - struct extInfo info; - union mscclppSocketAddress* rankAddresses = NULL; - union mscclppSocketAddress* rankAddressesRoot = NULL; // for initial rank <-> root information exchange - union mscclppSocketAddress* zero = NULL; - MSCCLPPCHECKGOTO(mscclppCalloc(&zero, 1), res, out); +} // namespace + +/* Socket Interface Selection type */ +enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; + +struct UnexpectedMsg { + int peer; + int tag; + std::shared_ptr sock; +}; + +struct ExtInfo { + int rank; + int nRanks; + mscclppSocketAddress extAddressListenRoot; + mscclppSocketAddress extAddressListen; +}; + +struct UniqueIdInternal { + uint64_t magic; + union mscclppSocketAddress addr; +}; +static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId"); + +class Bootstrap::Impl { + public: + Impl(int rank, int nRanks); + ~Impl(); + void initialize(const UniqueId uniqueId); + void initialize(std::string ipPortPair); + void establishConnections(); + UniqueId createUniqueId(); + UniqueId getUniqueId() const; + int getRank(); + int getNranks(); + void allGather(void* allData, int size); + void send(void* data, int size, int peer, int tag); + void recv(void* data, int size, int peer, int tag); + void barrier(); + void close(); + + private: + UniqueIdInternal uniqueId_; + int rank_; + int nRanks_; + bool netInitialized; + mscclppSocket listenSock_; + mscclppSocket ringRecvSocket_; + mscclppSocket ringSendSocket_; + std::vector peerCommAddresses_; + std::list unexpectedMessages_; + std::vector barrierArr_; + volatile uint32_t* abortFlag_; + std::thread rootThread_; + char netIfName_[MAX_IF_NAME_SIZE + 1]; + mscclppSocketAddress netIfAddr_; + + void netSend(mscclppSocket* sock, const void* data, int size); + void netRecv(mscclppSocket* sock, void* data, int size); + + void bootstrapCreateRoot(); + void bootstrapRoot(mscclppSocket listenSock); + void getRemoteAddresses(mscclppSocket* listenSock, std::vector& rankAddresses, + std::vector& rankAddressesRoot, int& rank); + void sendHandleToPeer(int peer, const std::vector& rankAddresses, + const std::vector& rankAddressesRoot); + void netInit(std::string ipPortPair); +}; + +// UniqueId MscclppBootstrap::Impl::uniqueId_; + +Bootstrap::Impl::Impl(int rank, int nRanks) + : rank_(rank), + nRanks_(nRanks), + netInitialized(false), + peerCommAddresses_(nRanks, mscclppSocketAddress()), + barrierArr_(nRanks, 0), + abortFlag_(nullptr) {} + +UniqueId Bootstrap::Impl::getUniqueId() const { + UniqueId ret; + std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_)); + return ret; +} + +UniqueId Bootstrap::Impl::createUniqueId() { + netInit(""); + MSCCLPPTHROW(getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic))); + std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress)); + bootstrapCreateRoot(); + return getUniqueId(); +} + +int Bootstrap::Impl::getRank() { return rank_; } + +int Bootstrap::Impl::getNranks() { return nRanks_; } + +void Bootstrap::Impl::initialize(const UniqueId uniqueId) { + netInit(""); + + std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); + + establishConnections(); +} + +void Bootstrap::Impl::initialize(std::string ipPortPair) { + netInit(ipPortPair); + + uniqueId_.magic = 0xdeadbeef; + std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress)); + MSCCLPPTHROW(mscclppSocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str())); + + if (rank_ == 0) { + bootstrapCreateRoot(); + } + + establishConnections(); +} + +Bootstrap::Impl::~Impl() { + if (rootThread_.joinable()) { + rootThread_.join(); + } +} + +void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector& rankAddresses, + std::vector& rankAddressesRoot, int& rank) { + mscclppSocket sock; + ExtInfo info; + + mscclppSocketAddress zero; + std::memset(&zero, 0, sizeof(mscclppSocketAddress)); + MSCCLPPTHROW(mscclppSocketInit(&sock)); + MSCCLPPTHROW(mscclppSocketAccept(&sock, listenSock)); + netRecv(&sock, &info, sizeof(info)); + MSCCLPPTHROW(mscclppSocketClose(&sock)); + + if (this->nRanks_ != info.nRanks) { + throw mscclpp::Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " + + std::to_string(info.nRanks), + ErrorCode::InternalError); + } + + if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) { + throw mscclpp::Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) + + " has already checked in", + ErrorCode::InternalError); + } + + // Save the connection handle for that rank + rankAddressesRoot[info.rank] = info.extAddressListenRoot; + rankAddresses[info.rank] = info.extAddressListen; + rank = info.rank; +} + +void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector& rankAddresses, + const std::vector& rankAddressesRoot) { + mscclppSocket sock; + int next = (peer + 1) % this->nRanks_; + MSCCLPPTHROW(mscclppSocketInit(&sock, &rankAddressesRoot[peer], this->uniqueId_.magic, mscclppSocketTypeBootstrap)); + MSCCLPPTHROW(mscclppSocketConnect(&sock)); + netSend(&sock, &rankAddresses[next], sizeof(mscclppSocketAddress)); + MSCCLPPTHROW(mscclppSocketClose(&sock)); +} + +void Bootstrap::Impl::bootstrapCreateRoot() { + mscclppSocket listenSock; + + // mscclppSocket* listenSock = new mscclppSocket(); // TODO(saemal) make this a shared ptr + MSCCLPPTHROW( + mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0)); + MSCCLPPTHROW(mscclppSocketListen(&listenSock)); + MSCCLPPTHROW(mscclppSocketGetAddr(&listenSock, &uniqueId_.addr)); + auto lambda = [this, listenSock]() { this->bootstrapRoot(listenSock); }; + rootThread_ = std::thread(lambda); +} + +void Bootstrap::Impl::bootstrapRoot(mscclppSocket listenSock) { + int numCollected = 0; + std::vector rankAddresses(this->nRanks_, mscclppSocketAddress()); + // for initial rank <-> root information exchange + std::vector rankAddressesRoot(this->nRanks_, mscclppSocketAddress()); + + std::memset(rankAddresses.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_); + std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_); setFilesLimit(); TRACE(MSCCLPP_INIT, "BEGIN"); /* Receive addresses from all ranks */ do { - struct mscclppSocket sock; - MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), res, out); - MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, listenSock), res, out); - MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &info, sizeof(info)), res, out); - MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out); - - if (c == 0) { - nranks = info.nranks; - MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddresses, nranks), res, out); - MSCCLPPCHECKGOTO(mscclppCalloc(&rankAddressesRoot, nranks), res, out); - } - - if (nranks != info.nranks) { - WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nranks); - goto out; - } - - if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union mscclppSocketAddress)) != 0) { - WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks); - goto out; - } - - // Save the connection handle for that rank - memcpy(rankAddressesRoot + info.rank, &info.extAddressListenRoot, sizeof(union mscclppSocketAddress)); - memcpy(rankAddresses + info.rank, &info.extAddressListen, sizeof(union mscclppSocketAddress)); - - ++c; - TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks); - } while (c < nranks); - TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", nranks); + int rank; + getRemoteAddresses(&listenSock, rankAddresses, rankAddressesRoot, rank); + ++numCollected; + TRACE(MSCCLPP_INIT, "Received connect from rank %d total %d/%d", rank, numCollected, this->nRanks_); + } while (numCollected < this->nRanks_); + TRACE(MSCCLPP_INIT, "COLLECTED ALL %d HANDLES", this->nRanks_); // Send the connect handle for the next rank in the AllGather ring - for (int r = 0; r < nranks; ++r) { - int next = (r + 1) % nranks; - struct mscclppSocket sock; - MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, rankAddressesRoot + r, magic, mscclppSocketTypeBootstrap), res, out); - MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), res, out); - MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, rankAddresses + next, sizeof(union mscclppSocketAddress)), res, out); - MSCCLPPCHECKGOTO(mscclppSocketClose(&sock), res, out); + for (int peer = 0; peer < this->nRanks_; ++peer) { + sendHandleToPeer(peer, rankAddresses, rankAddressesRoot); } - TRACE(MSCCLPP_INIT, "SENT OUT ALL %d HANDLES", nranks); - -out: - if (listenSock != NULL) { - mscclppSocketClose(listenSock); - free(listenSock); - } - if (rankAddresses) - free(rankAddresses); - if (rankAddressesRoot) - free(rankAddressesRoot); - if (zero) - free(zero); - free(rargs); TRACE(MSCCLPP_INIT, "DONE"); - return NULL; } -mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle) -{ - struct mscclppSocket* listenSock; - struct bootstrapRootArgs* args; - pthread_t thread; - - MSCCLPPCHECK(mscclppCalloc(&listenSock, 1)); - MSCCLPPCHECK(mscclppSocketInit(listenSock, &handle->addr, handle->magic, mscclppSocketTypeBootstrap, NULL, 0)); - MSCCLPPCHECK(mscclppSocketListen(listenSock)); - MSCCLPPCHECK(mscclppSocketGetAddr(listenSock, &handle->addr)); - - MSCCLPPCHECK(mscclppCalloc(&args, 1)); - args->listenSock = listenSock; - args->magic = handle->magic; - NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, (void*)args), 0); - mscclppSetThreadName(thread, "MSCCLPP BootstrapR"); - NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d - return mscclppSuccess; -} - -// #include -// #include - -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, const char* ip_port_pair) -{ - memset(handle, 0, sizeof(mscclppBootstrapHandle)); - const char* env = NULL; - - if (ip_port_pair) { - env = ip_port_pair; - } else { - env = getenv("MSCCLPP_COMM_ID"); - } - if (env) { - handle->magic = 0xdeadbeef; - - INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env); - if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) { - WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); - return mscclppInvalidArgument; +void Bootstrap::Impl::netInit(std::string ipPortPair) { + if (netInitialized) return; + if (!ipPortPair.empty()) { + mscclppSocketAddress remoteAddr; + if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) { + throw mscclpp::Error( + "Invalid ipPortPair, please use format: : or []: or :", + ErrorCode::InvalidUsage); + } + if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + throw mscclpp::Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError); } - if (isRoot) - MSCCLPPCHECK(bootstrapCreateRoot(handle)); } else { - MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); - memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress)); - MSCCLPPCHECK(bootstrapCreateRoot(handle)); + int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1); + if (ret <= 0) { + throw mscclpp::Error("Bootstrap : no socket interface found", ErrorCode::InternalError); + } } - // printf("addr = %s port = %d\n", inet_ntoa(handle->addr.sin.sin_addr), (int)ntohs(handle->addr.sin.sin_port)); - // printf("addr = %s\n", inet_ntoa((*(struct sockaddr_in*)&handle->addr.sa).sin_addr)); - return mscclppSuccess; + char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; + std::sprintf(line, " %s:", netIfName_); + mscclppSocketToString(&netIfAddr_, line + strlen(line)); + INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line); + netInitialized = true; } -struct unexConn -{ - int peer; - int tag; - struct mscclppSocket sock; - struct unexConn* next; -}; - -struct bootstrapState -{ - struct mscclppSocket listenSock; - struct mscclppSocket ringRecvSocket; - struct mscclppSocket ringSendSocket; - union mscclppSocketAddress* peerCommAddresses; - union mscclppSocketAddress* peerProxyAddresses; - struct unexConn* unexpectedConnections; - int cudaDev; - int rank; - int nranks; - uint64_t magic; - volatile uint32_t* abortFlag; -}; - -mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm) -{ - int rank = comm->rank; - int nranks = comm->nRanks; - struct bootstrapState* state; - struct mscclppSocket* proxySocket; +void Bootstrap::Impl::establishConnections() { mscclppSocketAddress nextAddr; - struct mscclppSocket sock, listenSockRoot; - struct extInfo info; + mscclppSocket sock, listenSockRoot; + ExtInfo info; - MSCCLPPCHECK(mscclppCalloc(&state, 1)); - state->rank = rank; - state->nranks = nranks; - state->abortFlag = comm->abortFlag; - comm->bootstrap = state; - comm->magic = state->magic = handle->magic; + TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_); - TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks); - - info.rank = rank; - info.nranks = nranks; + info.rank = this->rank_; + info.nRanks = this->nRanks_; + uint64_t magic = this->uniqueId_.magic; // Create socket for other ranks to contact me - MSCCLPPCHECK(mscclppSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, - comm->abortFlag)); - MSCCLPPCHECK(mscclppSocketListen(&state->listenSock)); - MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, &info.extAddressListen)); + MSCCLPPTHROW(mscclppSocketInit(&this->listenSock_, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + MSCCLPPTHROW(mscclppSocketListen(&this->listenSock_)); + MSCCLPPTHROW(mscclppSocketGetAddr(&this->listenSock_, &info.extAddressListen)); // Create socket for root to contact me - MSCCLPPCHECK( - mscclppSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); - MSCCLPPCHECK(mscclppSocketListen(&listenSockRoot)); - MSCCLPPCHECK(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); + MSCCLPPTHROW(mscclppSocketInit(&listenSockRoot, &netIfAddr_, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + MSCCLPPTHROW(mscclppSocketListen(&listenSockRoot)); + MSCCLPPTHROW(mscclppSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); // stagger connection times to avoid an overload of the root - if (nranks > 128) { - long msec = rank; - struct timespec tv; - tv.tv_sec = msec / 1000; - tv.tv_nsec = 1000000 * (msec % 1000); - TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, msec); + auto randomSleep = [](int rank) { + timespec tv; + tv.tv_sec = rank / 1000; + tv.tv_nsec = 1000000 * (rank % 1000); + TRACE(MSCCLPP_INIT, "rank %d delaying connection to root by %ld msec", rank, rank); (void)nanosleep(&tv, NULL); + }; + if (this->nRanks_ > 128) { + randomSleep(this->rank_); } + char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; + std::sprintf(line, " %s:", netIfName_); + mscclppSocketToString(&this->uniqueId_.addr, line + strlen(line)); + // send info on my listening socket to root - MSCCLPPCHECK(mscclppSocketInit(&sock, &handle->addr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); - MSCCLPPCHECK(mscclppSocketConnect(&sock)); - MSCCLPPCHECK(bootstrapNetSend(&sock, &info, sizeof(info))); - MSCCLPPCHECK(mscclppSocketClose(&sock)); + MSCCLPPTHROW(mscclppSocketInit(&sock, &this->uniqueId_.addr, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + MSCCLPPTHROW(mscclppSocketConnect(&sock)); + netSend(&sock, &info, sizeof(info)); + MSCCLPPTHROW(mscclppSocketClose(&sock)); // get info on my "next" rank in the bootstrap ring from root - MSCCLPPCHECK(mscclppSocketInit(&sock)); - MSCCLPPCHECK(mscclppSocketAccept(&sock, &listenSockRoot)); - MSCCLPPCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union mscclppSocketAddress))); - MSCCLPPCHECK(mscclppSocketClose(&sock)); - MSCCLPPCHECK(mscclppSocketClose(&listenSockRoot)); + MSCCLPPTHROW(mscclppSocketInit(&sock)); + MSCCLPPTHROW(mscclppSocketAccept(&sock, &listenSockRoot)); + netRecv(&sock, &nextAddr, sizeof(mscclppSocketAddress)); + MSCCLPPTHROW(mscclppSocketClose(&sock)); + MSCCLPPTHROW(mscclppSocketClose(&listenSockRoot)); - MSCCLPPCHECK( - mscclppSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, mscclppSocketTypeBootstrap, comm->abortFlag)); - MSCCLPPCHECK(mscclppSocketConnect(&state->ringSendSocket)); + MSCCLPPTHROW( + mscclppSocketInit(&this->ringSendSocket_, &nextAddr, magic, mscclppSocketTypeBootstrap, this->abortFlag_)); + MSCCLPPTHROW(mscclppSocketConnect(&this->ringSendSocket_)); // Accept the connect request from the previous rank in the AllGather ring - MSCCLPPCHECK(mscclppSocketInit(&state->ringRecvSocket)); - MSCCLPPCHECK(mscclppSocketAccept(&state->ringRecvSocket, &state->listenSock)); + MSCCLPPTHROW(mscclppSocketInit(&this->ringRecvSocket_)); + MSCCLPPTHROW(mscclppSocketAccept(&this->ringRecvSocket_, &this->listenSock_)); // AllGather all listen handlers - MSCCLPPCHECK(mscclppCalloc(&state->peerCommAddresses, nranks)); - MSCCLPPCHECK(mscclppSocketGetAddr(&state->listenSock, state->peerCommAddresses + rank)); - MSCCLPPCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union mscclppSocketAddress))); + MSCCLPPTHROW(mscclppSocketGetAddr(&this->listenSock_, &this->peerCommAddresses_[rank_])); + allGather(this->peerCommAddresses_.data(), sizeof(mscclppSocketAddress)); - // Create the service proxy - MSCCLPPCHECK(mscclppCalloc(&state->peerProxyAddresses, nranks)); - - // proxy is aborted through a message; don't set abortFlag - MSCCLPPCHECK(mscclppCalloc(&proxySocket, 1)); - MSCCLPPCHECK( - mscclppSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, mscclppSocketTypeProxy, comm->abortFlag)); - MSCCLPPCHECK(mscclppSocketListen(proxySocket)); - MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, state->peerProxyAddresses + rank)); - MSCCLPPCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union mscclppSocketAddress))); - // MSCCLPPCHECK(mscclppProxyInit(comm, proxySocket, state->peerProxyAddresses)); - - TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks); - - return mscclppSuccess; + TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_); } -mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size) -{ - struct bootstrapState* state = (struct bootstrapState*)commState; - char* data = (char*)allData; - int rank = state->rank; - int nranks = state->nranks; +void Bootstrap::Impl::allGather(void* allData, int size) { + char* data = static_cast(allData); + int rank = this->rank_; + int nRanks = this->nRanks_; - TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nranks, size); + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d", rank, nRanks, size); /* Simple ring based AllGather * At each step i receive data from (rank-i-1) from left * and send previous step's data from (rank-i) to right */ - for (int i = 0; i < nranks - 1; i++) { - size_t rslice = (rank - i - 1 + nranks) % nranks; - size_t sslice = (rank - i + nranks) % nranks; + for (int i = 0; i < nRanks - 1; i++) { + size_t rSlice = (rank - i - 1 + nRanks) % nRanks; + size_t sSlice = (rank - i + nRanks) % nRanks; // Send slice to the right - MSCCLPPCHECK(bootstrapNetSend(&state->ringSendSocket, data + sslice * size, size)); + netSend(&this->ringSendSocket_, data + sSlice * size, size); // Recv slice from the left - MSCCLPPCHECK(bootstrapNetRecv(&state->ringRecvSocket, data + rslice * size, size)); + netRecv(&this->ringRecvSocket_, data + rSlice * size, size); } - TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); - return mscclppSuccess; + TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size); } -mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) -{ - mscclppResult_t ret = mscclppSuccess; - struct bootstrapState* state = (struct bootstrapState*)commState; - struct mscclppSocket sock; - - MSCCLPPCHECKGOTO(mscclppSocketInit(&sock, state->peerCommAddresses + peer, state->magic, mscclppSocketTypeBootstrap, - state->abortFlag), - ret, fail); - MSCCLPPCHECKGOTO(mscclppSocketConnect(&sock), ret, fail); - MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail); - MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail); - MSCCLPPCHECKGOTO(bootstrapNetSend(&sock, data, size), ret, fail); - -exit: - MSCCLPPCHECK(mscclppSocketClose(&sock)); - return ret; -fail: - goto exit; +void Bootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size) { + MSCCLPPTHROW(mscclppSocketSend(sock, &size, sizeof(int))); + MSCCLPPTHROW(mscclppSocketSend(sock, const_cast(data), size)); } -mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag) -{ - if (nranks == 1) - return mscclppSuccess; - TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag); - - /* Simple intra process barrier - * - * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet, - * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988" - */ - int data[1]; - for (int mask = 1; mask < nranks; mask <<= 1) { - int src = (rank - mask + nranks) % nranks; - int dst = (rank + mask) % nranks; - MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data))); - MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data))); +void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) { + int recvSize; + MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int))); + if (recvSize > size) { + throw mscclpp::Error( + "Message truncated : received " + std::to_string(recvSize) + " bytes instead of " + std::to_string(size), + ErrorCode::InvalidUsage); } - - TRACE(MSCCLPP_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag); - return mscclppSuccess; + MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size))); } -mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size) -{ - if (nranks == 1) - return mscclppSuccess; - char* data = (char*)allData; - TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - ENTER", rank, nranks, size); +void Bootstrap::Impl::send(void* data, int size, int peer, int tag) { + mscclppSocket sock; + MSCCLPPTHROW(mscclppSocketInit(&sock, &this->peerCommAddresses_[peer], this->uniqueId_.magic, + mscclppSocketTypeBootstrap, this->abortFlag_)); + MSCCLPPTHROW(mscclppSocketConnect(&sock)); + netSend(&sock, &this->rank_, sizeof(int)); + netSend(&sock, &tag, sizeof(int)); + netSend(&sock, data, size); - for (int i = 1; i < nranks; i++) { - int src = (rank - i + nranks) % nranks; - int dst = (rank + i) % nranks; - MSCCLPPCHECK(bootstrapSend(commState, ranks[dst], /*tag=*/i, data + rank * size, size)); - MSCCLPPCHECK(bootstrapRecv(commState, ranks[src], /*tag=*/i, data + src * size, size)); - } - - TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); - return mscclppSuccess; + MSCCLPPTHROW(mscclppSocketClose(&sock)); } -mscclppResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock) -{ - // New unex - struct unexConn* unex; - MSCCLPPCHECK(mscclppCalloc(&unex, 1)); - unex->peer = peer; - unex->tag = tag; - memcpy(&unex->sock, sock, sizeof(struct mscclppSocket)); - - // Enqueue - struct unexConn* list = state->unexpectedConnections; - if (list == NULL) { - state->unexpectedConnections = unex; - return mscclppSuccess; +void Bootstrap::Impl::recv(void* data, int size, int peer, int tag) { + // search over all unexpected messages + auto lambda = [peer, tag](const UnexpectedMsg& msg) { return msg.peer == peer && msg.tag == tag; }; + auto it = std::find_if(unexpectedMessages_.begin(), unexpectedMessages_.end(), lambda); + if (it != unexpectedMessages_.end()) { + // found a match + netRecv(it->sock.get(), data, size); + MSCCLPPTHROW(mscclppSocketClose(it->sock.get())); + unexpectedMessages_.erase(it); + return; } - while (list->next) - list = list->next; - list->next = unex; - return mscclppSuccess; -} - -mscclppResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct mscclppSocket* sock, - int* found) -{ - struct unexConn* elem = state->unexpectedConnections; - struct unexConn* prev = NULL; - *found = 0; - while (elem) { - if (elem->peer == peer && elem->tag == tag) { - if (prev == NULL) { - state->unexpectedConnections = elem->next; - } else { - prev->next = elem->next; - } - memcpy(sock, &elem->sock, sizeof(struct mscclppSocket)); - free(elem); - *found = 1; - return mscclppSuccess; - } - prev = elem; - elem = elem->next; - } - return mscclppSuccess; -} - -static void unexpectedFree(struct bootstrapState* state) -{ - struct unexConn* elem = state->unexpectedConnections; - struct unexConn* prev = NULL; - - while (elem) { - prev = elem; - elem = elem->next; - free(prev); - } - return; -} - -// We can't know who we'll receive from, so we need to receive everything at once -mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) -{ - mscclppResult_t ret = mscclppSuccess; - struct bootstrapState* state = (struct bootstrapState*)commState; - struct mscclppSocket sock; - int newPeer, newTag; - - // Search unexpected connections first - int found; - MSCCLPPCHECK(unexpectedDequeue(state, peer, tag, &sock, &found)); - if (found) { - MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); - goto exit; - } - - // Then look for new connections - while (1) { - MSCCLPPCHECKGOTO(mscclppSocketInit(&sock), ret, fail); - MSCCLPPCHECKGOTO(mscclppSocketAccept(&sock, &state->listenSock), ret, fail); - MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail); - MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail); + // didn't find one + while (true) { + auto sock = std::make_shared(); + int newPeer, newTag; + MSCCLPPTHROW(mscclppSocketInit(sock.get())); + MSCCLPPTHROW(mscclppSocketAccept(sock.get(), &this->listenSock_)); + netRecv(sock.get(), &newPeer, sizeof(int)); + netRecv(sock.get(), &newTag, sizeof(int)); if (newPeer == peer && newTag == tag) { - MSCCLPPCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), ret, fail); - goto exit; + netRecv(sock.get(), ((char*)data), size); + MSCCLPPTHROW(mscclppSocketClose(sock.get())); + return; } - // Unexpected connection. Save for later. - MSCCLPPCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail); + // Unexpected message. Save for later. + unexpectedMessages_.push_back({newPeer, newTag, sock}); } -exit: - MSCCLPPCHECK(mscclppSocketClose(&sock)); - return ret; -fail: - goto exit; } -mscclppResult_t bootstrapClose(void* commState) -{ - struct bootstrapState* state = (struct bootstrapState*)commState; - if (state->unexpectedConnections != NULL) { - unexpectedFree(state); - if (*state->abortFlag == 0) { - WARN("Unexpected connections are not empty"); - return mscclppInternalError; - } - } +void Bootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); } - MSCCLPPCHECK(mscclppSocketClose(&state->listenSock)); - MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket)); - MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket)); - - free(state->peerCommAddresses); - free(state); - - return mscclppSuccess; +void Bootstrap::Impl::close() { + MSCCLPPTHROW(mscclppSocketClose(&this->listenSock_)); + MSCCLPPTHROW(mscclppSocketClose(&this->ringSendSocket_)); + MSCCLPPTHROW(mscclppSocketClose(&this->ringRecvSocket_)); } -mscclppResult_t bootstrapAbort(void* commState) -{ - struct bootstrapState* state = (struct bootstrapState*)commState; - if (commState == NULL) - return mscclppSuccess; - MSCCLPPCHECK(mscclppSocketClose(&state->listenSock)); - MSCCLPPCHECK(mscclppSocketClose(&state->ringSendSocket)); - MSCCLPPCHECK(mscclppSocketClose(&state->ringRecvSocket)); - free(state->peerCommAddresses); - free(state->peerProxyAddresses); - free(state); - return mscclppSuccess; +MSCCLPP_API_CPP Bootstrap::Bootstrap(int rank, int nRanks) { + // pimpl_ = std::make_unique(ipPortPair, rank, nRanks, uniqueId); + pimpl_ = std::make_unique(rank, nRanks); } + +MSCCLPP_API_CPP UniqueId Bootstrap::createUniqueId() { return pimpl_->createUniqueId(); } + +MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const { return pimpl_->getUniqueId(); } + +MSCCLPP_API_CPP int Bootstrap::getRank() { return pimpl_->getRank(); } + +MSCCLPP_API_CPP int Bootstrap::getNranks() { return pimpl_->getNranks(); } + +MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); } + +MSCCLPP_API_CPP void Bootstrap::recv(void* data, int size, int peer, int tag) { pimpl_->recv(data, size, peer, tag); } + +MSCCLPP_API_CPP void Bootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); } + +MSCCLPP_API_CPP void Bootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); } + +MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->initialize(ipPortPair); } + +MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); } + +MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); } diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index b3998d91..b60815b6 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -5,25 +5,23 @@ ************************************************************************/ #include "socket.h" -#include "config.h" -#include "utils.h" - -#include #include #include +#include #include +#include "config.h" +#include "utils.h" + static mscclppResult_t socketProgressOpt(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset, - int block, int* closed) -{ + int block, int* closed) { int bytes = 0; *closed = 0; char* data = (char*)ptr; char line[SOCKET_NAME_MAXLEN + 1]; do { - if (op == MSCCLPP_SOCKET_RECV) - bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT); + if (op == MSCCLPP_SOCKET_RECV) bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT); if (op == MSCCLPP_SOCKET_SEND) bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); if (op == MSCCLPP_SOCKET_RECV && bytes == 0) { @@ -48,8 +46,7 @@ static mscclppResult_t socketProgressOpt(int op, struct mscclppSocket* sock, voi return mscclppSuccess; } -static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) -{ +static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) { int closed; MSCCLPPCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); if (closed) { @@ -60,10 +57,8 @@ static mscclppResult_t socketProgress(int op, struct mscclppSocket* sock, void* return mscclppSuccess; } -static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) -{ - while (*offset < size) - MSCCLPPCHECK(socketProgress(op, sock, ptr, size, offset)); +static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) { + while (*offset < size) MSCCLPPCHECK(socketProgress(op, sock, ptr, size, offset)); return mscclppSuccess; } @@ -71,10 +66,8 @@ static mscclppResult_t socketWait(int op, struct mscclppSocket* sock, void* ptr, * * Output: "IPv4/IPv6 address" */ -const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) -{ - if (buf == NULL || addr == NULL) - return NULL; +const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) { + if (buf == NULL || addr == NULL) return NULL; struct sockaddr* saddr = &addr->sa; if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0] = '\0'; @@ -90,68 +83,58 @@ const char* mscclppSocketToString(union mscclppSocketAddress* addr, char* buf, c return buf; } -static uint16_t socketToPort(union mscclppSocketAddress* addr) -{ +static uint16_t socketToPort(union mscclppSocketAddress* addr) { struct sockaddr* saddr = &addr->sa; return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port); } /* Allow the user to force the IPv4/IPv6 interface selection */ -static int envSocketFamily(void) -{ - int family = -1; // Family selection is not forced, will use first one found +static int envSocketFamily(void) { + int family = -1; // Family selection is not forced, will use first one found char* env = getenv("MSCCLPP_SOCKET_FAMILY"); - if (env == NULL) - return family; + if (env == NULL) return family; INFO(MSCCLPP_ENV, "MSCCLPP_SOCKET_FAMILY set by environment to %s", env); if (strcmp(env, "AF_INET") == 0) - family = AF_INET; // IPv4 + family = AF_INET; // IPv4 else if (strcmp(env, "AF_INET6") == 0) - family = AF_INET6; // IPv6 + family = AF_INET6; // IPv6 return family; } static int findInterfaces(const char* prefixList, char* names, union mscclppSocketAddress* addrs, int sock_family, - int maxIfNameSize, int maxIfs) -{ + int maxIfNameSize, int maxIfs) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN + 1]; #endif struct netIf userIfs[MAX_IFS]; bool searchNot = prefixList && prefixList[0] == '^'; - if (searchNot) - prefixList++; + if (searchNot) prefixList++; bool searchExact = prefixList && prefixList[0] == '='; - if (searchExact) - prefixList++; + if (searchExact) prefixList++; int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS); int found = 0; struct ifaddrs *interfaces, *interface; getifaddrs(&interfaces); for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) { - if (interface->ifa_addr == NULL) - continue; + if (interface->ifa_addr == NULL) continue; /* We only support IPv4 & IPv6 */ int family = interface->ifa_addr->sa_family; - if (family != AF_INET && family != AF_INET6) - continue; + if (family != AF_INET && family != AF_INET6) continue; TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Found interface %s:%s", interface->ifa_name, mscclppSocketToString((union mscclppSocketAddress*)interface->ifa_addr, line)); /* Allow the caller to force the socket family type */ - if (sock_family != -1 && family != sock_family) - continue; + if (sock_family != -1 && family != sock_family) continue; /* We also need to skip IPv6 loopback interfaces */ if (family == AF_INET6) { struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr); - if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) - continue; + if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue; } // check against user specified interfaces @@ -183,8 +166,7 @@ static int findInterfaces(const char* prefixList, char* names, union mscclppSock return found; } -static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* remote) -{ +static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* remote) { /* Check family first */ int family = local_if.ifa_addr->sa_family; if (family != remote->sa.sa_family) { @@ -207,8 +189,8 @@ static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* rem struct in6_addr& mask_in6 = mask->sin6_addr; struct in6_addr& remote_in6 = remote_addr.sin6_addr; bool same = true; - int len = 16; // IPv6 address is 16 unsigned char - for (int c = 0; c < len; c++) { // Network byte order is big-endian + int len = 16; // IPv6 address is 16 unsigned char + for (int c = 0; c < len; c++) { // Network byte order is big-endian char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c]; char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c]; if (c1 ^ c2) { @@ -228,8 +210,7 @@ static bool matchSubnet(struct ifaddrs local_if, union mscclppSocketAddress* rem } int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* localAddrs, - union mscclppSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) -{ + union mscclppSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN + 1]; #endif @@ -238,13 +219,11 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l struct ifaddrs *interfaces, *interface; getifaddrs(&interfaces); for (interface = interfaces; interface && !found; interface = interface->ifa_next) { - if (interface->ifa_addr == NULL) - continue; + if (interface->ifa_addr == NULL) continue; /* We only support IPv4 & IPv6 */ int family = interface->ifa_addr->sa_family; - if (family != AF_INET && family != AF_INET6) - continue; + if (family != AF_INET && family != AF_INET6) continue; // check against user specified interfaces if (!matchSubnet(*interface, remoteAddr)) { @@ -262,8 +241,7 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l interface->ifa_name, mscclppSocketToString(localAddrs + found, line), mscclppSocketToString(remoteAddr, line_a)); found++; - if (found == maxIfs) - break; + if (found == maxIfs) break; } if (found == 0) { @@ -273,8 +251,7 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l return found; } -mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, const char* ip_port_pair) -{ +mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, const char* ip_port_pair) { if (!(ip_port_pair && strlen(ip_port_pair) > 1)) { WARN("Net : string is null"); return mscclppInvalidArgument; @@ -305,36 +282,34 @@ mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, c if (p->ai_family == AF_INET) { struct sockaddr_in& sin = ua->sin; memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in)); - sin.sin_family = AF_INET; // IPv4 + sin.sin_family = AF_INET; // IPv4 // inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address - sin.sin_port = htons(ni.port); // port + sin.sin_port = htons(ni.port); // port } else if (p->ai_family == AF_INET6) { struct sockaddr_in6& sin6 = ua->sin6; memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6)); - sin6.sin6_family = AF_INET6; // IPv6 - sin6.sin6_port = htons(ni.port); // port - sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete - sin6.sin6_scope_id = 0; // should be global scope, set to 0 + sin6.sin6_family = AF_INET6; // IPv6 + sin6.sin6_port = htons(ni.port); // port + sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete + sin6.sin6_scope_id = 0; // should be global scope, set to 0 } else { WARN("Net : unsupported IP family"); return mscclppInvalidArgument; } - freeaddrinfo(p); // all done with this structure + freeaddrinfo(p); // all done with this structure } else { int i, j = -1, len = strlen(ip_port_pair); for (i = 1; i < len; i++) { - if (ip_port_pair[i] == '%') - j = i; - if (ip_port_pair[i] == ']') - break; + if (ip_port_pair[i] == '%') j = i; + if (ip_port_pair[i] == ']') break; } if (i == len) { WARN("Net : No valid [IPv6]:port pair found"); return mscclppInvalidArgument; } - bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope + bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ]; memset(ip_str, '\0', sizeof(ip_str)); @@ -343,21 +318,19 @@ mscclppResult_t mscclppSocketGetAddrFromString(union mscclppSocketAddress* ua, c strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1); strncpy(port_str, ip_port_pair + i + 2, len - i - 1); int port = atoi(port_str); - if (!global_scope) - strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name + if (!global_scope) strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name struct sockaddr_in6& sin6 = ua->sin6; - sin6.sin6_family = AF_INET6; // IPv6 - inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address - sin6.sin6_port = htons(port); // port - sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete - sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope + sin6.sin6_family = AF_INET6; // IPv6 + inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address + sin6.sin6_port = htons(port); // port + sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete + sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope } return mscclppSuccess; } -int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) -{ +int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) { static int shownIfName = 0; int nIfs = 0; // Allow user to force the INET socket family selection @@ -367,8 +340,7 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, in if (env && strlen(env) > 1) { INFO(MSCCLPP_ENV, "MSCCLPP_SOCKET_IFNAME set by environment to %s", env); // Specified by user : find or fail - if (shownIfName++ == 0) - INFO(MSCCLPP_NET, "MSCCLPP_SOCKET_IFNAME set to %s", env); + if (shownIfName++ == 0) INFO(MSCCLPP_NET, "MSCCLPP_SOCKET_IFNAME set to %s", env); nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); } else { // Try to automatically pick the right one @@ -386,19 +358,15 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, in } } // Then look for anything else (but not docker or lo) - if (nIfs == 0) - nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); + if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); // Finally look for docker, then lo. - if (nIfs == 0) - nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); - if (nIfs == 0) - nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); + if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); + if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); } return nIfs; } -mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock) -{ +mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock) { if (sock == NULL) { WARN("mscclppSocketListen: pass NULL socket"); return mscclppInvalidArgument; @@ -438,20 +406,17 @@ mscclppResult_t mscclppSocketListen(struct mscclppSocket* sock) return mscclppSuccess; } -mscclppResult_t mscclppSocketGetAddr(struct mscclppSocket* sock, union mscclppSocketAddress* addr) -{ +mscclppResult_t mscclppSocketGetAddr(struct mscclppSocket* sock, union mscclppSocketAddress* addr) { if (sock == NULL) { WARN("mscclppSocketGetAddr: pass NULL socket"); return mscclppInvalidArgument; } - if (sock->state != mscclppSocketStateReady) - return mscclppInternalError; + if (sock->state != mscclppSocketStateReady) return mscclppInternalError; memcpy(addr, &sock->addr, sizeof(union mscclppSocketAddress)); return mscclppSuccess; } -static mscclppResult_t socketTryAccept(struct mscclppSocket* sock) -{ +static mscclppResult_t socketTryAccept(struct mscclppSocket* sock) { static bool timeInitialized = false; static mscclppTime_t initTime; if (!timeInitialized) { @@ -482,14 +447,12 @@ static mscclppResult_t socketTryAccept(struct mscclppSocket* sock) return mscclppSuccess; } -static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock) -{ +static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock) { uint64_t magic; enum mscclppSocketType type; int received = 0; MSCCLPPCHECK(mscclppSocketProgress(MSCCLPP_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); - if (received == 0) - return mscclppSuccess; + if (received == 0) return mscclppSuccess; MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); if (magic != sock->magic) { WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); @@ -514,8 +477,7 @@ static mscclppResult_t socketFinalizeAccept(struct mscclppSocket* sock) return mscclppSuccess; } -static mscclppResult_t socketStartConnect(struct mscclppSocket* sock) -{ +static mscclppResult_t socketStartConnect(struct mscclppSocket* sock) { static bool timeInitialized = false; static mscclppTime_t initTime; if (!timeInitialized) { @@ -543,8 +505,7 @@ static mscclppResult_t socketStartConnect(struct mscclppSocket* sock) return mscclppRemoteError; } usleep(SLEEP_INT); - if (++sock->connectRetries % 1000 == 0) - INFO(MSCCLPP_ALL, "Call to connect returned %s, retrying", strerror(errno)); + if (++sock->connectRetries % 1000 == 0) INFO(MSCCLPP_ALL, "Call to connect returned %s, retrying", strerror(errno)); return mscclppSuccess; } else { char line[SOCKET_NAME_MAXLEN + 1]; @@ -555,8 +516,7 @@ static mscclppResult_t socketStartConnect(struct mscclppSocket* sock) } } -static mscclppResult_t socketPollConnect(struct mscclppSocket* sock) -{ +static mscclppResult_t socketPollConnect(struct mscclppSocket* sock) { static bool timeInitialized = false; static mscclppTime_t initTime; if (!timeInitialized) { @@ -608,8 +568,7 @@ static mscclppResult_t socketPollConnect(struct mscclppSocket* sock) return mscclppSuccess; } -mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock) -{ +mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock) { if (sock == NULL) { WARN("mscclppSocketPollConnect: pass NULL socket"); return mscclppInvalidArgument; @@ -618,12 +577,10 @@ mscclppResult_t mscclppSocketPollConnect(struct mscclppSocket* sock) return mscclppSuccess; } -static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock) -{ +static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock) { int sent = 0; MSCCLPPCHECK(socketProgress(MSCCLPP_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); - if (sent == 0) - return mscclppSuccess; + if (sent == 0) return mscclppSuccess; MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); sent = 0; MSCCLPPCHECK(socketWait(MSCCLPP_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); @@ -631,8 +588,7 @@ static mscclppResult_t socketFinalizeConnect(struct mscclppSocket* sock) return mscclppSuccess; } -static mscclppResult_t socketProgressState(struct mscclppSocket* sock) -{ +static mscclppResult_t socketProgressState(struct mscclppSocket* sock) { if (sock->state == mscclppSocketStateAccepting) { MSCCLPPCHECK(socketTryAccept(sock)); } @@ -668,8 +624,7 @@ static mscclppResult_t socketProgressState(struct mscclppSocket* sock) // return mscclppSuccess; // } -mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock) -{ +mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN + 1]; #endif @@ -686,8 +641,7 @@ mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock) if (sock->state != mscclppSocketStateInitialized) { WARN("mscclppSocketConnect: wrong socket state %d", sock->state); - if (sock->state == mscclppSocketStateError) - return mscclppRemoteError; + if (sock->state == mscclppSocketStateError) return mscclppRemoteError; return mscclppInternalError; } TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Connecting to socket %s", mscclppSocketToString(&sock->addr, line)); @@ -701,25 +655,23 @@ mscclppResult_t mscclppSocketConnect(struct mscclppSocket* sock) (sock->state == mscclppSocketStateConnecting || sock->state == mscclppSocketStateConnectPolling || sock->state == mscclppSocketStateConnected)); - if (sock->abortFlag && *sock->abortFlag != 0) - return mscclppInternalError; + if (sock->abortFlag && *sock->abortFlag != 0) return mscclppInternalError; switch (sock->state) { - case mscclppSocketStateConnecting: - case mscclppSocketStateConnectPolling: - case mscclppSocketStateConnected: - case mscclppSocketStateReady: - return mscclppSuccess; - case mscclppSocketStateError: - return mscclppSystemError; - default: - WARN("mscclppSocketConnect: wrong socket state %d", sock->state); - return mscclppInternalError; + case mscclppSocketStateConnecting: + case mscclppSocketStateConnectPolling: + case mscclppSocketStateConnected: + case mscclppSocketStateReady: + return mscclppSuccess; + case mscclppSocketStateError: + return mscclppSystemError; + default: + WARN("mscclppSocketConnect: wrong socket state %d", sock->state); + return mscclppInternalError; } } -mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSocket* listenSock) -{ +mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSocket* listenSock) { mscclppResult_t ret = mscclppSuccess; if (listenSock == NULL || sock == NULL) { @@ -747,35 +699,32 @@ mscclppResult_t mscclppSocketAccept(struct mscclppSocket* sock, struct mscclppSo } while (sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) && (sock->state == mscclppSocketStateAccepting || sock->state == mscclppSocketStateAccepted)); - if (sock->abortFlag && *sock->abortFlag != 0) - return mscclppInternalError; + if (sock->abortFlag && *sock->abortFlag != 0) return mscclppInternalError; switch (sock->state) { - case mscclppSocketStateAccepting: - case mscclppSocketStateAccepted: - case mscclppSocketStateReady: - ret = mscclppSuccess; - break; - case mscclppSocketStateError: - ret = mscclppSystemError; - break; - default: - WARN("mscclppSocketAccept: wrong socket state %d", sock->state); - ret = mscclppInternalError; - break; + case mscclppSocketStateAccepting: + case mscclppSocketStateAccepted: + case mscclppSocketStateReady: + ret = mscclppSuccess; + break; + case mscclppSocketStateError: + ret = mscclppSystemError; + break; + default: + WARN("mscclppSocketAccept: wrong socket state %d", sock->state); + ret = mscclppInternalError; + break; } exit: return ret; } -mscclppResult_t mscclppSocketInit(struct mscclppSocket* sock, union mscclppSocketAddress* addr, uint64_t magic, - enum mscclppSocketType type, volatile uint32_t* abortFlag, int asyncFlag) -{ +mscclppResult_t mscclppSocketInit(struct mscclppSocket* sock, const mscclppSocketAddress* addr, uint64_t magic, + enum mscclppSocketType type, volatile uint32_t* abortFlag, int asyncFlag) { mscclppResult_t ret = mscclppSuccess; - if (sock == NULL) - goto exit; + if (sock == NULL) goto exit; sock->connectRetries = 0; sock->acceptRetries = 0; sock->abortFlag = abortFlag; @@ -824,8 +773,7 @@ fail: goto exit; } -mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) -{ +mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void* ptr, int size, int* offset) { if (sock == NULL) { WARN("mscclppSocketProgress: pass NULL socket"); return mscclppInvalidArgument; @@ -843,8 +791,7 @@ mscclppResult_t mscclppSocketProgress(int op, struct mscclppSocket* sock, void* // return mscclppSuccess; // } -mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int size) -{ +mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int size) { int offset = 0; if (sock == NULL) { WARN("mscclppSocketSend: pass NULL socket"); @@ -858,8 +805,7 @@ mscclppResult_t mscclppSocketSend(struct mscclppSocket* sock, void* ptr, int siz return mscclppSuccess; } -mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int size) -{ +mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int size) { int offset = 0; if (sock == NULL) { WARN("mscclppSocketRecv: pass NULL socket"); @@ -888,11 +834,9 @@ mscclppResult_t mscclppSocketRecv(struct mscclppSocket* sock, void* ptr, int siz // return mscclppSuccess; // } -mscclppResult_t mscclppSocketClose(struct mscclppSocket* sock) -{ +mscclppResult_t mscclppSocketClose(struct mscclppSocket* sock) { if (sock != NULL) { - if (sock->fd >= 0) - close(sock->fd); + if (sock->fd >= 0) close(sock->fd); sock->state = mscclppSocketStateClosed; sock->fd = -1; } diff --git a/src/c_style_remnants.cc b/src/c_style_remnants.cc new file mode 100644 index 00000000..98b6273f --- /dev/null +++ b/src/c_style_remnants.cc @@ -0,0 +1,39 @@ +#include "api.h" +#include "config.h" +#include "debug.h" +#include "mscclpp.h" + +MSCCLPP_API void mscclppDefaultLogHandler(const char* msg) { mscclppDebugDefaultLogHandler(msg); } + +MSCCLPP_API mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler) { + return mscclppDebugSetLogHandler(handler); +} + +MSCCLPP_API mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) { + mscclppConfig* config = mscclppConfig::getInstance(); + config->setBootstrapConnectionTimeoutConfig(timeout); + return mscclppSuccess; +} + +MSCCLPP_API const char* mscclppGetErrorString(mscclppResult_t code) { + switch (code) { + case mscclppSuccess: + return "no error"; + case mscclppUnhandledCudaError: + return "unhandled cuda error"; + case mscclppSystemError: + return "unhandled system error"; + case mscclppInternalError: + return "internal error"; + case mscclppInvalidArgument: + return "invalid argument"; + case mscclppInvalidUsage: + return "invalid usage"; + case mscclppRemoteError: + return "remote process exited or there was a network error"; + case mscclppInProgress: + return "MSCCL++ operation in progress"; + default: + return "unknown result code"; + } +} diff --git a/src/channel.cc b/src/channel.cc new file mode 100644 index 00000000..f9564f79 --- /dev/null +++ b/src/channel.cc @@ -0,0 +1,27 @@ +#include + +#include "api.h" +#include "checks.hpp" +#include "debug.h" +#include "utils.h" + +namespace mscclpp { +namespace channel { + +MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) + : communicator_(communicator), + proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { + int cudaDevice; + CUDATHROW(cudaGetDevice(&cudaDevice)); + MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode)); +} + +MSCCLPP_API_CPP void DeviceChannelService::bindThread() { + if (deviceNumaNode >= 0) { + MSCCLPPTHROW(numaBind(deviceNumaNode)); + INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode); + } +} + +} // namespace channel +} // namespace mscclpp diff --git a/src/communicator.cc b/src/communicator.cc new file mode 100644 index 00000000..b7b6923d --- /dev/null +++ b/src/communicator.cc @@ -0,0 +1,139 @@ +#include "communicator.hpp" + +#include +#include + +#include "api.h" +#include "checks.hpp" +#include "connection.hpp" +#include "debug.h" +#include "registered_memory.hpp" +#include "utils.h" + +namespace mscclpp { + +Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) { + rankToHash_.resize(bootstrap->getNranks()); + auto hostHash = getHostHash(); + INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); + rankToHash_[bootstrap->getRank()] = hostHash; + bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); + + CUDATHROW(cudaStreamCreateWithFlags(&ipcStream_, cudaStreamNonBlocking)); +} + +Communicator::Impl::~Impl() { + ibContexts_.clear(); + + cudaStreamDestroy(ipcStream_); +} + +IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { + // Find IB context or create it + auto it = ibContexts_.find(ibTransport); + if (it == ibContexts_.end()) { + auto ibDev = getIBDeviceName(ibTransport); + ibContexts_[ibTransport] = std::make_unique(ibDev); + return ibContexts_[ibTransport].get(); + } else { + return it->second.get(); + } +} + +cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; } + +MSCCLPP_API_CPP Communicator::~Communicator() = default; + +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) + : pimpl(std::make_unique(bootstrap)) {} + +MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrapper() { return pimpl->bootstrap_; } + +MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { + return RegisteredMemory( + std::make_shared(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl)); +} + +struct MemorySender : public Setuppable { + MemorySender(RegisteredMemory memory, int remoteRank, int tag) + : memory_(memory), remoteRank_(remoteRank), tag_(tag) {} + + void beginSetup(std::shared_ptr bootstrap) override { + bootstrap->send(memory_.serialize(), remoteRank_, tag_); + } + + RegisteredMemory memory_; + int remoteRank_; + int tag_; +}; + +MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) { + onSetup(std::make_shared(memory, remoteRank, tag)); +} + +struct MemoryReceiver : public Setuppable { + MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} + + void endSetup(std::shared_ptr bootstrap) override { + std::vector data; + bootstrap->recv(data, remoteRank_, tag_); + memoryPromise_.set_value(RegisteredMemory::deserialize(data)); + } + + std::promise memoryPromise_; + int remoteRank_; + int tag_; +}; + +MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) { + auto memoryReceiver = std::make_shared(remoteRank, tag); + onSetup(memoryReceiver); + return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); +} + +MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) { + std::shared_ptr conn; + if (transport == Transport::CudaIpc) { + // sanity check: make sure the IPC connection is being made within a node + if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) { + std::stringstream ss; + ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex + << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" + << " != " << pimpl->bootstrap_->getRank() << "(" << std::hex + << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; + throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage); + } + auto cudaIpcConn = std::make_shared(remoteRank, tag, pimpl->getIpcStream()); + conn = cudaIpcConn; + INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", + pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank, + pimpl->rankToHash_[remoteRank]); + } else if (AllIBTransports.has(transport)) { + auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); + conn = ibConn; + INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", + pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); + } else { + throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError); + } + pimpl->connections_.push_back(conn); + onSetup(conn); + return conn; +} + +MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr setuppable) { + pimpl->toSetup_.push_back(setuppable); +} + +MSCCLPP_API_CPP void Communicator::setup() { + for (auto& setuppable : pimpl->toSetup_) { + setuppable->beginSetup(pimpl->bootstrap_); + } + for (auto& setuppable : pimpl->toSetup_) { + setuppable->endSetup(pimpl->bootstrap_); + } + pimpl->toSetup_.clear(); +} + +} // namespace mscclpp diff --git a/src/config.cc b/src/config.cc index 069bfbe0..e4640216 100644 --- a/src/config.cc +++ b/src/config.cc @@ -2,17 +2,8 @@ mscclppConfig mscclppConfig::_instance; -mscclppConfig* mscclppConfig::getInstance() -{ - return &_instance; -} +mscclppConfig* mscclppConfig::getInstance() { return &_instance; } -time_t mscclppConfig::getBootstrapConnectionTimeoutConfig() -{ - return bootstrapConnectionTimeout; -} +time_t mscclppConfig::getBootstrapConnectionTimeoutConfig() { return bootstrapConnectionTimeout; } -void mscclppConfig::setBootstrapConnectionTimeoutConfig(time_t timeout) -{ - bootstrapConnectionTimeout = timeout; -} +void mscclppConfig::setBootstrapConnectionTimeoutConfig(time_t timeout) { bootstrapConnectionTimeout = timeout; } diff --git a/src/connection.cc b/src/connection.cc new file mode 100644 index 00000000..1fce9b89 --- /dev/null +++ b/src/connection.cc @@ -0,0 +1,151 @@ +#include "connection.hpp" + +#include + +#include "checks.hpp" +#include "infiniband/verbs.h" +#include "npkit/npkit.h" +#include "registered_memory.hpp" +#include "utils.hpp" + +namespace mscclpp { + +void validateTransport(RegisteredMemory mem, Transport transport) { + if (!mem.transports().has(transport)) { + throw Error("RegisteredMemory does not support this transport", ErrorCode::InvalidUsage); + } +} + +// Connection + +std::shared_ptr Connection::getRegisteredMemoryImpl(RegisteredMemory& mem) { return mem.pimpl; } + +// ConnectionBase + +ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} + +int ConnectionBase::remoteRank() { return remoteRank_; } + +int ConnectionBase::tag() { return tag_; } + +// CudaIpcConnection + +CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream) + : ConnectionBase(remoteRank, tag), stream_(stream) {} + +CudaIpcConnection::~CudaIpcConnection() {} + +Transport CudaIpcConnection::transport() { return Transport::CudaIpc; } + +Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; } + +void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + char* dstPtr = (char*)dst.data(); + char* srcPtr = (char*)src.data(); + + CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream_)); + INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size); + + // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); +} + +void CudaIpcConnection::flush() { + CUDATHROW(cudaStreamSynchronize(stream_)); + // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); +} + +// IBConnection + +IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) + : ConnectionBase(remoteRank, tag), + transport_(transport), + remoteTransport_(Transport::Unknown), + numSignaledSends(0) { + qp = commImpl.getIbContext(transport)->createQp(); +} + +Transport IBConnection::transport() { return transport_; } + +Transport IBConnection::remoteTransport() { return remoteTransport_; } + +void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); + if (dstTransportInfo.ibLocal) { + throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage); + } + auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport()); + if (!srcTransportInfo.ibLocal) { + throw Error("src is remote, which is not supported", ErrorCode::InvalidUsage); + } + + auto dstMrInfo = dstTransportInfo.ibMrInfo; + auto srcMr = srcTransportInfo.ibMr; + + qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, + /*signaled=*/true); + numSignaledSends++; + qp->postSend(); + INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, + (uint8_t*)dstMrInfo.addr + dstOffset, size); + // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); +} + +void IBConnection::flush() { + Timer timer; + while (numSignaledSends) { + int wcNum = qp->pollCq(); + if (wcNum < 0) { + throw mscclpp::IbError("pollCq failed: error no " + std::to_string(errno), errno); + } + + auto elapsed = timer.elapsed(); + if (elapsed > MSCCLPP_POLLING_WAIT) { + throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " + + std::to_string(numSignaledSends) + " signals", + ErrorCode::InternalError); + } + for (int i = 0; i < wcNum; ++i) { + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); + if (wc->status != IBV_WC_SUCCESS) { + throw mscclpp::IbError("pollCq failed: status " + std::to_string(wc->status), wc->status); + } + if (wc->opcode == IBV_WC_RDMA_WRITE) { + numSignaledSends--; + } + } + } + // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); +} + +void IBConnection::beginSetup(std::shared_ptr bootstrap) { + std::vector ibQpTransport; + std::copy_n(reinterpret_cast(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport)); + std::copy_n(reinterpret_cast(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport)); + + bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); +} + +void IBConnection::endSetup(std::shared_ptr bootstrap) { + std::vector ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport)); + bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); + + IbQpInfo qpInfo; + auto it = ibQpTransport.begin(); + std::copy_n(it, sizeof(qpInfo), reinterpret_cast(&qpInfo)); + it += sizeof(qpInfo); + std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast(&remoteTransport_)); + it += sizeof(qpInfo); + + qp->rtr(qpInfo); + qp->rts(); +} + +} // namespace mscclpp diff --git a/src/debug.cc b/src/debug.cc index a3807d3e..9841a4b7 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -5,17 +5,19 @@ ************************************************************************/ #include "debug.h" -#include "core.h" + +#include #include #include #include +#include int mscclppDebugLevel = -1; static int pid = -1; static char hostname[1024]; thread_local int mscclppDebugNoWarn = 0; -char mscclppLastError[1024] = ""; // Global string for the last error in human readable form -uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT +char mscclppLastError[1024] = ""; // Global string for the last error in human readable form +uint64_t mscclppDebugMask = MSCCLPP_INIT; // Default debug sub-system mask is INIT FILE* mscclppDebugFile = stdout; mscclppLogHandler_t mscclppDebugLogHandler = NULL; pthread_mutex_t mscclppDebugLock = PTHREAD_MUTEX_INITIALIZER; @@ -23,13 +25,9 @@ std::chrono::steady_clock::time_point mscclppEpoch; static __thread int tid = -1; -void mscclppDebugDefaultLogHandler(const char* msg) -{ - fwrite(msg, 1, strlen(msg), mscclppDebugFile); -} +void mscclppDebugDefaultLogHandler(const char* msg) { fwrite(msg, 1, strlen(msg), mscclppDebugFile); } -void mscclppDebugInit() -{ +void mscclppDebugInit() { pthread_mutex_lock(&mscclppDebugLock); if (mscclppDebugLevel != -1) { pthread_mutex_unlock(&mscclppDebugLock); @@ -120,33 +118,32 @@ void mscclppDebugInit() continue; } switch (mscclppDebugFileEnv[c++]) { - case '%': // Double % - *dfn++ = '%'; - break; - case 'h': // %h = hostname - dfn += snprintf(dfn, PATH_MAX, "%s", hostname); - break; - case 'p': // %p = pid - dfn += snprintf(dfn, PATH_MAX, "%d", pid); - break; - default: // Echo everything we don't understand - *dfn++ = '%'; - *dfn++ = mscclppDebugFileEnv[c - 1]; - break; + case '%': // Double % + *dfn++ = '%'; + break; + case 'h': // %h = hostname + dfn += snprintf(dfn, PATH_MAX, "%s", hostname); + break; + case 'p': // %p = pid + dfn += snprintf(dfn, PATH_MAX, "%d", pid); + break; + default: // Echo everything we don't understand + *dfn++ = '%'; + *dfn++ = mscclppDebugFileEnv[c - 1]; + break; } } *dfn = '\0'; if (debugFn[0] != '\0') { FILE* file = fopen(debugFn, "w"); if (file != nullptr) { - setbuf(file, nullptr); // disable buffering + setbuf(file, nullptr); // disable buffering mscclppDebugFile = file; } } } - if (mscclppDebugLogHandler == NULL) - mscclppDebugLogHandler = mscclppDefaultLogHandler; + if (mscclppDebugLogHandler == NULL) mscclppDebugLogHandler = mscclppDefaultLogHandler; mscclppEpoch = std::chrono::steady_clock::now(); __atomic_store_n(&mscclppDebugLevel, tempNcclDebugLevel, __ATOMIC_RELEASE); @@ -158,10 +155,8 @@ void mscclppDebugInit() * they can share the debugging mechanisms and output files */ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char* filefunc, int line, const char* fmt, - ...) -{ - if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) - mscclppDebugInit(); + ...) { + if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) mscclppDebugInit(); if (mscclppDebugNoWarn != 0 && level == MSCCLPP_LOG_WARN) { level = MSCCLPP_LOG_INFO; flags = mscclppDebugNoWarn; @@ -175,8 +170,7 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char va_end(vargs); pthread_mutex_unlock(&mscclppDebugLock); } - if (mscclppDebugLevel < level || ((flags & mscclppDebugMask) == 0)) - return; + if (mscclppDebugLevel < level || ((flags & mscclppDebugMask) == 0)) return; if (tid == -1) { tid = syscall(SYS_gettid); @@ -217,27 +211,19 @@ void mscclppDebugLog(mscclppDebugLogLevel level, unsigned long flags, const char } } -mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler) -{ - if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) - mscclppDebugInit(); - if (handler == NULL) - return mscclppInvalidArgument; +mscclppResult_t mscclppDebugSetLogHandler(mscclppLogHandler_t handler) { + if (__atomic_load_n(&mscclppDebugLevel, __ATOMIC_ACQUIRE) == -1) mscclppDebugInit(); + if (handler == NULL) return mscclppInvalidArgument; pthread_mutex_lock(&mscclppDebugLock); mscclppDebugLogHandler = handler; pthread_mutex_unlock(&mscclppDebugLock); return mscclppSuccess; } -MSCCLPP_PARAM(SetThreadName, "SET_THREAD_NAME", 0); - -void mscclppSetThreadName(pthread_t thread, const char* fmt, ...) -{ +void mscclppSetThreadName(pthread_t thread, const char* fmt, ...) { // pthread_setname_np is nonstandard GNU extension // needs the following feature test macro #ifdef _GNU_SOURCE - if (mscclppParamSetThreadName() != 1) - return; char threadName[MSCCLPP_THREAD_NAMELEN]; va_list vargs; va_start(vargs, fmt); diff --git a/src/epoch.cc b/src/epoch.cc new file mode 100644 index 00000000..7f29c92e --- /dev/null +++ b/src/epoch.cc @@ -0,0 +1,69 @@ +#include + +#include "alloc.h" +#include "api.h" +#include "checks.hpp" + +namespace mscclpp { + +BaseEpoch::BaseEpoch(std::shared_ptr connection) : connection_(connection) {} + +void BaseEpoch::setup(Communicator& communicator) { + localEpochIdsRegMem_ = communicator.registerMemory(epochIds_, sizeof(epochIds_), connection_->transport()); + communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection_->remoteRank(), connection_->tag()); + remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection_->remoteRank(), connection_->tag()); +} + +void BaseEpoch::signal() { + connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica), localEpochIdsRegMem_, + offsetof(EpochIds, outbound), sizeof(epochIds_)); +} + +MSCCLPP_API_CPP DeviceEpoch::DeviceEpoch(Communicator& communicator, std::shared_ptr connection) + : BaseEpoch(connection) { + MSCCLPPTHROW(mscclppCudaCalloc(&epochIds_, 1)); + MSCCLPPTHROW(mscclppCudaCalloc(&expectedInboundEpochId_, 1)); + setup(communicator); +} + +MSCCLPP_API_CPP DeviceEpoch::~DeviceEpoch() { + mscclppCudaFree(epochIds_); + mscclppCudaFree(expectedInboundEpochId_); +} + +MSCCLPP_API_CPP void DeviceEpoch::signal() { BaseEpoch::signal(); } + +MSCCLPP_API_CPP DeviceEpoch::DeviceHandle DeviceEpoch::deviceHandle() { + DeviceEpoch::DeviceHandle device; + device.epochIds = epochIds_; + device.expectedInboundEpochId = expectedInboundEpochId_; + return device; +} + +MSCCLPP_API_CPP HostEpoch::HostEpoch(Communicator& communicator, std::shared_ptr connection) + : BaseEpoch(connection) { + if (connection->transport() == Transport::CudaIpc) { + throw Error("HostEpoch cannot be used with CudaIpc transport", ErrorCode::InvalidUsage); + } + epochIds_ = new EpochIds(); + expectedInboundEpochId_ = new uint64_t(); + setup(communicator); +} + +MSCCLPP_API_CPP HostEpoch::~HostEpoch() { + delete epochIds_; + delete expectedInboundEpochId_; +} + +MSCCLPP_API_CPP void HostEpoch::increamentAndSignal() { + *(volatile uint64_t*)&(epochIds_->outbound) += 1; + signal(); +} + +MSCCLPP_API_CPP void HostEpoch::wait() { + (*expectedInboundEpochId_) += 1; + while (*(volatile uint64_t*)&(epochIds_->inboundReplica) < (*expectedInboundEpochId_)) + ; +} + +} // namespace mscclpp diff --git a/src/errors.cc b/src/errors.cc new file mode 100644 index 00000000..50d7a2ef --- /dev/null +++ b/src/errors.cc @@ -0,0 +1,48 @@ +#include +#include + +#include "api.h" + +namespace mscclpp { + +std::string errorToString(enum ErrorCode error) { + switch (error) { + case ErrorCode::SystemError: + return "SystemError"; + case ErrorCode::InternalError: + return "InternalError"; + case ErrorCode::InvalidUsage: + return "InvalidUsage"; + default: + return "UnknownError"; + } +} + +BaseError::BaseError(std::string message, int errorCode) + : std::runtime_error(""), message_(message), errorCode_(errorCode) {} + +BaseError::BaseError(int errorCode) : std::runtime_error(""), errorCode_(errorCode) {} + +int BaseError::getErrorCode() const { return errorCode_; } + +const char* BaseError::what() const noexcept { return message_.c_str(); } + +MSCCLPP_API_CPP Error::Error(std::string message, ErrorCode errorCode) : BaseError(static_cast(errorCode)) { + message_ = message + " (Mscclpp failure: " + errorToString(errorCode) + ")"; +} + +MSCCLPP_API_CPP CudaError::CudaError(std::string message, cudaError_t errorCode) : BaseError(errorCode) { + message_ = message + " (Cuda failure: " + cudaGetErrorString(errorCode) + ")"; +} + +MSCCLPP_API_CPP CuError::CuError(std::string message, CUresult errorCode) : BaseError(errorCode) { + const char* errStr; + cuGetErrorString(errorCode, &errStr); + message_ = message + " (Cu failure: " + errStr + ")"; +} + +MSCCLPP_API_CPP IbError::IbError(std::string message, int errorCode) : BaseError(errorCode) { + message_ = message + " (Ib failure: " + std::strerror(errorCode) + ")"; +} + +}; // namespace mscclpp diff --git a/src/fifo.cc b/src/fifo.cc new file mode 100644 index 00000000..e4571254 --- /dev/null +++ b/src/fifo.cc @@ -0,0 +1,68 @@ +#include +#include + +#include +#include + +#include "alloc.h" +#include "api.h" +#include "checks.hpp" + +namespace mscclpp { + +struct HostProxyFifo::Impl { + DeviceProxyFifo deviceFifo; + + // allocated on the host. Only accessed by the host. This is a copy of the + // value pointed to by fifoTailDev and the invariant is that + // *fifoTailDev <= hostTail. Meaning that host's copy of tail is + // always ahead of the device's copy and host updates the device's copy + // only when it is needed. Therefore, hostTail is the "true" tail + // and fifoTailDev is a "stale" tail. See proxy.cc to undertand how + // these updates are pushed to the device. + uint64_t hostTail; + + // for transferring fifo tail + cudaStream_t stream; +}; + +MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() { + pimpl = std::make_unique(); + MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1)); + MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE)); + MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.tailReplica, 1)); + CUDATHROW(cudaStreamCreateWithFlags(&pimpl->stream, cudaStreamNonBlocking)); + pimpl->hostTail = 0; +} + +MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() { + mscclppCudaFree(pimpl->deviceFifo.head); + mscclppCudaHostFree(pimpl->deviceFifo.triggers); + mscclppCudaFree(pimpl->deviceFifo.tailReplica); + cudaStreamDestroy(pimpl->stream); +} + +MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) { + __m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]); + _mm_store_si128((__m128i*)trigger, xmm0); +} + +MSCCLPP_API_CPP void HostProxyFifo::pop() { + *(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0; + (pimpl->hostTail)++; +} + +MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) { + // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure + // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush + // request. + CUDATHROW(cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, + pimpl->stream)); + if (sync) { + CUDATHROW(cudaStreamSynchronize(pimpl->stream)); + } +} + +MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() { return pimpl->deviceFifo; } + +} // namespace mscclpp diff --git a/src/gdr.cc b/src/gdr.cc deleted file mode 100644 index 95cd6870..00000000 --- a/src/gdr.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "gdr.h" - -// Used to make the GDR library calls thread safe -pthread_mutex_t gdrLock = PTHREAD_MUTEX_INITIALIZER; - -gdr_t wrap_gdr_open(void) -{ - return gdr_open(); -} - -mscclppResult_t wrap_gdr_close(gdr_t g) -{ - int ret = gdr_close(g); - if (ret != 0) { - WARN("gdr_close() failed: %d", ret); - return mscclppSystemError; - } - return mscclppSuccess; -} - -mscclppResult_t wrap_gdr_pin_buffer(gdr_t g, unsigned long addr, size_t size, uint64_t p2p_token, uint32_t va_space, - gdr_mh_t* handle) -{ - int ret; - GDRLOCKCALL(gdr_pin_buffer(g, addr, size, p2p_token, va_space, handle), ret); - if (ret != 0) { - WARN("gdr_pin_buffer(addr %lx, size %zi) failed: %d", addr, size, ret); - return mscclppSystemError; - } - return mscclppSuccess; -} - -mscclppResult_t wrap_gdr_unpin_buffer(gdr_t g, gdr_mh_t handle) -{ - int ret; - GDRLOCKCALL(gdr_unpin_buffer(g, handle), ret); - if (ret != 0) { - WARN("gdr_unpin_buffer(handle %lx) failed: %d", handle.h, ret); - return mscclppSystemError; - } - return mscclppSuccess; -} - -mscclppResult_t wrap_gdr_get_info(gdr_t g, gdr_mh_t handle, gdr_info_t* info) -{ - int ret; - GDRLOCKCALL(gdr_get_info(g, handle, info), ret); - if (ret != 0) { - WARN("gdr_get_info(handle %lx) failed: %d", handle.h, ret); - return mscclppSystemError; - } - return mscclppSuccess; -} - -mscclppResult_t wrap_gdr_map(gdr_t g, gdr_mh_t handle, void** va, size_t size) -{ - int ret; - GDRLOCKCALL(gdr_map(g, handle, va, size), ret); - if (ret != 0) { - WARN("gdr_map(handle %lx, size %zi) failed: %d", handle.h, size, ret); - return mscclppSystemError; - } - return mscclppSuccess; -} - -mscclppResult_t wrap_gdr_unmap(gdr_t g, gdr_mh_t handle, void* va, size_t size) -{ - int ret; - GDRLOCKCALL(gdr_unmap(g, handle, va, size), ret); - if (ret != 0) { - WARN("gdr_unmap(handle %lx, va %p, size %zi) failed: %d", handle.h, va, size, ret); - return mscclppSystemError; - } - return mscclppSuccess; -} diff --git a/src/ib.cc b/src/ib.cc index 18a4fa02..7eed6b5e 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -1,310 +1,154 @@ +#include "ib.hpp" + +#include +#include +#include + #include #include #include -#include -#include -#include +#include +#include +#include #include "alloc.h" -#include "comm.h" +#include "api.h" +#include "checks.hpp" #include "debug.h" -#include "ib.h" -static int getIbDevNumaNode(const char* ibDevPath) -{ - if (ibDevPath == NULL) { - WARN("ibDevPath is NULL"); - return -1; - } - const char* postfix = "/device/numa_node"; - FILE* fp = NULL; - char* filePath = NULL; - int node = -1; - int res; - if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) { - WARN("mscclppCalloc failed"); - goto exit; - } - memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char)); - filePath[strlen(ibDevPath)] = '\0'; - if (strncat(filePath, postfix, strlen(postfix)) == NULL) { - WARN("strncat failed"); - goto exit; - } - fp = fopen(filePath, "r"); - if (fp == NULL) { - WARN("fopen failed (errno %d, path %s)", errno, filePath); - goto exit; - } - res = fscanf(fp, "%d", &node); - if (res != 1) { - WARN("fscanf failed (errno %d, path %s)", errno, filePath); - node = -1; - goto exit; - } -exit: - if (filePath != NULL) { - free(filePath); - } - if (fp != NULL) { - fclose(fp); - } - return node; -} +#define MAXCONNECTIONS 64 -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName) -{ - struct mscclppIbContext* _ctx; - MSCCLPPCHECK(mscclppCalloc(&_ctx, 1)); +namespace mscclpp { - std::vector ports; - - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - for (int i = 0; i < num; ++i) { - if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) { - _ctx->ctx = ibv_open_device(devices[i]); - break; - } - } - ibv_free_device_list(devices); - if (_ctx->ctx == nullptr) { - WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; - } - - // Check available ports - struct ibv_device_attr devAttr; - if (ibv_query_device(_ctx->ctx, &devAttr) != 0) { - WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; - } - - for (uint8_t i = 1; i <= devAttr.phys_port_cnt; ++i) { - struct ibv_port_attr portAttr; - if (ibv_query_port(_ctx->ctx, i, &portAttr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, i); - goto fail; - } - if (portAttr.state != IBV_PORT_ACTIVE) { - continue; - } - if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) { - continue; - } - ports.push_back((int)i); - } - if (ports.size() == 0) { - WARN("no active IB port found"); - goto fail; - } - MSCCLPPCHECK(mscclppCalloc(&_ctx->ports, ports.size())); - _ctx->nPorts = (int)ports.size(); - for (int i = 0; i < _ctx->nPorts; ++i) { - _ctx->ports[i] = ports[i]; - } - - _ctx->pd = ibv_alloc_pd(_ctx->ctx); - if (_ctx->pd == NULL) { - WARN("ibv_alloc_pd failed (errno %d)", errno); - goto fail; - } - - *ctx = _ctx; - return mscclppSuccess; -fail: - *ctx = NULL; - if (_ctx->ports != NULL) { - free(_ctx->ports); - } - free(_ctx); - return mscclppInternalError; -} - -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx) -{ - for (int i = 0; i < ctx->nMrs; ++i) { - if (ctx->mrs[i].mr) { - ibv_dereg_mr(ctx->mrs[i].mr); - } - } - for (int i = 0; i < ctx->nQps; ++i) { - if (ctx->qps[i].qp) { - ibv_destroy_qp(ctx->qps[i].qp); - } - ibv_destroy_cq(ctx->qps[i].cq); - free(ctx->qps[i].wcs); - free(ctx->qps[i].sges); - free(ctx->qps[i].wrs); - } - if (ctx->pd != NULL) { - ibv_dealloc_pd(ctx->pd); - } - if (ctx->ctx != NULL) { - ibv_close_device(ctx->ctx); - } - free(ctx->mrs); - free(ctx->qps); - free(ctx->ports); - free(ctx); - return mscclppSuccess; -} - -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/) -{ - if (port < 0) { - port = ctx->ports[0]; - } else { - bool found = false; - for (int i = 0; i < ctx->nPorts; ++i) { - if (ctx->ports[i] == port) { - found = true; - break; - } - } - if (!found) { - WARN("invalid IB port: %d", port); - return mscclppInternalError; - } - } - - struct ibv_cq* cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0); - if (cq == NULL) { - WARN("ibv_create_cq failed (errno %d)", errno); - return mscclppInternalError; - } - - struct ibv_qp_init_attr qp_init_attr; - std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr)); - qp_init_attr.sq_sig_all = 0; - qp_init_attr.send_cq = cq; - qp_init_attr.recv_cq = cq; - qp_init_attr.qp_type = IBV_QPT_RC; - qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_send_sge = 1; - qp_init_attr.cap.max_recv_sge = 1; - qp_init_attr.cap.max_inline_data = 0; - struct ibv_qp* qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (qp == nullptr) { - WARN("ibv_create_qp failed (errno %d)", errno); - return mscclppInternalError; - } - struct ibv_port_attr port_attr; - if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, port); - return mscclppInternalError; - } - - // Register QP to this ctx - qp->context = ctx->ctx; - if (qp->context == NULL) { - WARN("IB context is NULL"); - return mscclppInternalError; - } - ctx->nQps++; - if (ctx->qps == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS)); - ctx->maxQps = MAXCONNECTIONS; - } - if (ctx->maxQps < ctx->nQps) { - WARN("too many QPs"); - return mscclppInternalError; - } - struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps - 1]; - _ibQp->qp = qp; - _ibQp->info.lid = port_attr.lid; - _ibQp->info.port = port; - _ibQp->info.linkLayer = port_attr.link_layer; - _ibQp->info.qpn = qp->qp_num; - _ibQp->info.mtu = port_attr.active_mtu; - _ibQp->info.is_grh = (port_attr.flags & IBV_QPF_GRH_REQUIRED); - - if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND || _ibQp->info.is_grh) { - union ibv_gid gid; - if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) { - WARN("ibv_query_gid failed (errno %d)", errno); - return mscclppInternalError; - } - _ibQp->info.spn = gid.global.subnet_prefix; - _ibQp->info.iid = gid.global.interface_id; - } - - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_INIT; - qp_attr.pkey_index = 0; - qp_attr.port_num = port; - qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; - if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - WARN("ibv_modify_qp failed (errno %d)", errno); - return mscclppInternalError; - } - - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wrs, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->sges, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wcs, MSCCLPP_IB_CQ_POLL_NUM)); - _ibQp->cq = cq; - - *ibQp = _ibQp; - - return mscclppSuccess; -} - -mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr) -{ +IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) { if (size == 0) { - WARN("invalid size: %zu", size); - return mscclppInvalidArgument; + throw std::invalid_argument("invalid size: " + std::to_string(size)); } static __thread uintptr_t pageSize = 0; if (pageSize == 0) { pageSize = sysconf(_SC_PAGESIZE); } uintptr_t addr = reinterpret_cast(buff) & -pageSize; - size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - struct ibv_mr* mr = - ibv_reg_mr(ctx->pd, reinterpret_cast(addr), pages * pageSize, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); - if (mr == nullptr) { - WARN("ibv_reg_mr failed (errno %d)", errno); - return mscclppInternalError; + std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; + struct ibv_pd* _pd = reinterpret_cast(pd); + struct ibv_mr* _mr = ibv_reg_mr( + _pd, reinterpret_cast(addr), pages * pageSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); + if (_mr == nullptr) { + std::stringstream err; + err << "ibv_reg_mr failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); } - ctx->nMrs++; - if (ctx->mrs == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS)); - ctx->maxMrs = MAXCONNECTIONS; - } - if (ctx->maxMrs < ctx->nMrs) { - WARN("too many MRs"); - return mscclppInternalError; - } - struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1]; - _ibMr->mr = mr; - _ibMr->buff = buff; - _ibMr->info.addr = (uint64_t)buff; - _ibMr->info.rkey = mr->rkey; - *ibMr = _ibMr; - return mscclppSuccess; + this->mr = _mr; + this->size = pages * pageSize; } -////////////////////////////////////////////////////////////////////////////// +IbMr::~IbMr() { ibv_dereg_mr(reinterpret_cast(this->mr)); } -int mscclppIbQp::rtr(const mscclppIbQpInfo* info) -{ +IbMrInfo IbMr::getInfo() const { + IbMrInfo info; + info.addr = reinterpret_cast(this->buff); + info.rkey = reinterpret_cast(this->mr)->rkey; + return info; +} + +const void* IbMr::getBuff() const { return this->buff; } + +uint32_t IbMr::getLkey() const { return reinterpret_cast(this->mr)->lkey; } + +IbQp::IbQp(void* ctx, void* pd, int port) { + struct ibv_context* _ctx = reinterpret_cast(ctx); + struct ibv_pd* _pd = reinterpret_cast(pd); + + this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); + if (this->cq == nullptr) { + std::stringstream err; + err << "ibv_create_cq failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + + struct ibv_qp_init_attr qpInitAttr; + std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); + qpInitAttr.sq_sig_all = 0; + qpInitAttr.send_cq = reinterpret_cast(this->cq); + qpInitAttr.recv_cq = reinterpret_cast(this->cq); + qpInitAttr.qp_type = IBV_QPT_RC; + qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_inline_data = 0; + + struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); + if (_qp == nullptr) { + std::stringstream err; + err << "ibv_create_qp failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + + struct ibv_port_attr portAttr; + if (ibv_query_port(_ctx, port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + this->info.lid = portAttr.lid; + this->info.port = port; + this->info.linkLayer = portAttr.link_layer; + this->info.qpn = _qp->qp_num; + this->info.mtu = portAttr.active_mtu; + this->info.is_grh = (portAttr.flags & IBV_QPF_GRH_REQUIRED); + + if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND || this->info.is_grh) { + union ibv_gid gid; + if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { + std::stringstream err; + err << "ibv_query_gid failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + this->info.spn = gid.global.subnet_prefix; + this->info.iid = gid.global.interface_id; + } + + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(qpAttr)); + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = 0; + qpAttr.port_num = port; + qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + this->qp = _qp; + this->wrn = 0; + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wrs), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->sges), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM)); +} + +IbQp::~IbQp() { + ibv_destroy_qp(reinterpret_cast(this->qp)); + ibv_destroy_cq(reinterpret_cast(this->cq)); + std::free(this->wrs); + std::free(this->sges); + std::free(this->wcs); +} + +void IbQp::rtr(const IbQpInfo& info) { struct ibv_qp_attr qp_attr; std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_RTR; - qp_attr.path_mtu = info->mtu; - qp_attr.dest_qp_num = info->qpn; + qp_attr.path_mtu = static_cast(info.mtu); + qp_attr.dest_qp_num = info.qpn; qp_attr.rq_psn = 0; qp_attr.max_dest_rd_atomic = 1; qp_attr.min_rnr_timer = 0x12; - if (info->linkLayer == IBV_LINK_LAYER_ETHERNET || info->is_grh) { + if (info.linkLayer == IBV_LINK_LAYER_ETHERNET || info.is_grh) { qp_attr.ah_attr.is_global = 1; - qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn; - qp_attr.ah_attr.grh.dgid.global.interface_id = info->iid; + qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; + qp_attr.ah_attr.grh.dgid.global.interface_id = info.iid; qp_attr.ah_attr.grh.flow_label = 0; qp_attr.ah_attr.grh.sgid_index = 0; qp_attr.ah_attr.grh.hop_limit = 255; @@ -312,17 +156,21 @@ int mscclppIbQp::rtr(const mscclppIbQpInfo* info) } else { qp_attr.ah_attr.is_global = 0; } - qp_attr.ah_attr.dlid = info->lid; + qp_attr.ah_attr.dlid = info.lid; qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; - qp_attr.ah_attr.port_num = info->port; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + qp_attr.ah_attr.port_num = info.port; + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } } -int mscclppIbQp::rts() -{ +void IbQp::rts() { struct ibv_qp_attr qp_attr; std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_RTS; @@ -331,75 +179,244 @@ int mscclppIbQp::rts() qp_attr.rnr_retry = 7; qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC); + int ret = ibv_modify_qp( + reinterpret_cast(this->qp), &qp_attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } } -int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled) -{ +int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled) { if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { return -1; } int wrn = this->wrn; - struct ibv_send_wr* wr_ = &this->wrs[wrn]; - struct ibv_sge* sge_ = &this->sges[wrn]; - // std::memset(wr_, 0, sizeof(struct ibv_send_wr)); - // std::memset(sge_, 0, sizeof(struct ibv_sge)); + struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + struct ibv_sge* sges_ = reinterpret_cast(this->sges); + + struct ibv_send_wr* wr_ = &wrs_[wrn]; + struct ibv_sge* sge_ = &sges_[wrn]; wr_->wr_id = wrId; wr_->sg_list = sge_; wr_->num_sge = 1; wr_->opcode = IBV_WR_RDMA_WRITE; wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0; - wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset; - wr_->wr.rdma.rkey = info->rkey; + wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset; + wr_->wr.rdma.rkey = info.rkey; wr_->next = nullptr; - sge_->addr = (uint64_t)(ibMr->buff) + srcOffset; + sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset; sge_->length = size; - sge_->lkey = ibMr->mr->lkey; + sge_->lkey = mr->getLkey(); if (wrn > 0) { - this->wrs[wrn - 1].next = wr_; + wrs_[wrn - 1].next = wr_; } this->wrn++; return this->wrn; } -int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) -{ - int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled); - this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - this->wrs[wrn - 1].imm_data = immData; +int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled, unsigned int immData) { + int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled); + struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs_[wrn - 1].imm_data = immData; return wrn; } -int mscclppIbQp::postSend() -{ +void IbQp::postSend() { if (this->wrn == 0) { - return 0; + return; } - struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(this->qp, this->wrs, &bad_wr); + int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), + &bad_wr); if (ret != 0) { - return ret; + std::stringstream err; + err << "ibv_post_send failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); } this->wrn = 0; - return 0; } -int mscclppIbQp::postRecv(uint64_t wrId) -{ +void IbQp::postRecv(uint64_t wrId) { struct ibv_recv_wr wr, *bad_wr; wr.wr_id = wrId; wr.sg_list = nullptr; wr.num_sge = 0; wr.next = nullptr; - return ibv_post_recv(this->qp, &wr, &bad_wr); + int ret = ibv_post_recv(reinterpret_cast(this->qp), &wr, &bad_wr); + if (ret != 0) { + std::stringstream err; + err << "ibv_post_recv failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } } -int mscclppIbQp::pollCq() -{ - return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs); +int IbQp::pollCq() { + return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, + reinterpret_cast(this->wcs)); } + +IbQpInfo& IbQp::getInfo() { return this->info; } + +const void* IbQp::getWc(int idx) const { return &reinterpret_cast(this->wcs)[idx]; } + +IbCtx::IbCtx(const std::string& devName) : devName(devName) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (std::string(devices[i]->name) == devName) { + this->ctx = ibv_open_device(devices[i]); + break; + } + } + ibv_free_device_list(devices); + if (this->ctx == nullptr) { + std::stringstream err; + err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; + throw mscclpp::IbError(err.str(), errno); + } + this->pd = ibv_alloc_pd(reinterpret_cast(this->ctx)); + if (this->pd == nullptr) { + std::stringstream err; + err << "ibv_alloc_pd failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } +} + +IbCtx::~IbCtx() { + this->mrs.clear(); + this->qps.clear(); + if (this->pd != nullptr) { + ibv_dealloc_pd(reinterpret_cast(this->pd)); + } + if (this->ctx != nullptr) { + ibv_close_device(reinterpret_cast(this->ctx)); + } +} + +bool IbCtx::isPortUsable(int port) const { + struct ibv_port_attr portAttr; + if (ibv_query_port(reinterpret_cast(this->ctx), port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; + throw mscclpp::IbError(err.str(), errno); + } + return portAttr.state == IBV_PORT_ACTIVE && + (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); +} + +int IbCtx::getAnyActivePort() const { + struct ibv_device_attr devAttr; + if (ibv_query_device(reinterpret_cast(this->ctx), &devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { + if (this->isPortUsable(port)) { + return port; + } + } + return -1; +} + +IbQp* IbCtx::createQp(int port /*=-1*/) { + if (port == -1) { + port = this->getAnyActivePort(); + if (port == -1) { + throw mscclpp::Error("No active port found", ErrorCode::InternalError); + } + } else if (!this->isPortUsable(port)) { + throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError); + } + qps.emplace_back(new IbQp(this->ctx, this->pd, port)); + return qps.back().get(); +} + +const IbMr* IbCtx::registerMr(void* buff, std::size_t size) { + mrs.emplace_back(new IbMr(this->pd, buff, size)); + return mrs.back().get(); +} + +const std::string& IbCtx::getDevName() const { return this->devName; } + +MSCCLPP_API_CPP int getIBDeviceCount() { + int num; + ibv_get_device_list(&num); + return num; +} + +MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + int ibTransportIndex; + switch (ibTransport) { // TODO: get rid of this ugly switch + case Transport::IB0: + ibTransportIndex = 0; + break; + case Transport::IB1: + ibTransportIndex = 1; + break; + case Transport::IB2: + ibTransportIndex = 2; + break; + case Transport::IB3: + ibTransportIndex = 3; + break; + case Transport::IB4: + ibTransportIndex = 4; + break; + case Transport::IB5: + ibTransportIndex = 5; + break; + case Transport::IB6: + ibTransportIndex = 6; + break; + case Transport::IB7: + ibTransportIndex = 7; + break; + default: + throw std::invalid_argument("Not an IB transport"); + } + if (ibTransportIndex >= num) { + throw std::out_of_range("IB transport out of range"); + } + return devices[ibTransportIndex]->name; +} + +MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (ibDeviceName == devices[i]->name) { + switch (i) { // TODO: get rid of this ugly switch + case 0: + return Transport::IB0; + case 1: + return Transport::IB1; + case 2: + return Transport::IB2; + case 3: + return Transport::IB3; + case 4: + return Transport::IB4; + case 5: + return Transport::IB5; + case 6: + return Transport::IB6; + case 7: + return Transport::IB7; + default: + throw std::out_of_range("IB device index out of range"); + } + } + } + throw std::invalid_argument("IB device not found"); +} + +} // namespace mscclpp diff --git a/src/include/align.h b/src/include/align.h index 008d2b44..981d943d 100644 --- a/src/include/align.h +++ b/src/include/align.h @@ -22,19 +22,19 @@ #endif #endif -template __host__ __device__ constexpr Z divUp(X x, Y y) -{ +template +__host__ __device__ constexpr Z divUp(X x, Y y) { return (x + y - 1) / y; } -template __host__ __device__ constexpr Z roundUp(X x, Y y) -{ +template +__host__ __device__ constexpr Z roundUp(X x, Y y) { return (x + y - 1) - (x + y - 1) % y; } // assumes second argument is a power of 2 -template __host__ __device__ constexpr Z alignUp(X x, int a) -{ +template +__host__ __device__ constexpr Z alignUp(X x, int a) { return (x + a - 1) & Z(-a); } diff --git a/src/include/alloc.h b/src/include/alloc.h index 496af197..5de23e87 100644 --- a/src/include/alloc.h +++ b/src/include/alloc.h @@ -7,17 +7,17 @@ #ifndef MSCCLPP_ALLOC_H_ #define MSCCLPP_ALLOC_H_ +#include +#include +#include + #include "align.h" #include "checks.h" #include "mscclpp.h" #include "utils.h" -#include -#include -#include -#include -template mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) -{ +template +mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; *ptr = nullptr; @@ -27,21 +27,19 @@ template mscclppResult_t mscclppCudaHostCallocDebug(T** ptr, size_t memset(*ptr, 0, nelem * sizeof(T)); finish: CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (*ptr == nullptr) - WARN("Failed to CUDA host alloc %ld bytes", nelem * sizeof(T)); + if (*ptr == nullptr) WARN("Failed to CUDA host alloc %ld bytes", nelem * sizeof(T)); INFO(MSCCLPP_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr); return result; } #define mscclppCudaHostCalloc(...) mscclppCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__) -inline mscclppResult_t mscclppCudaHostFree(void* ptr) -{ +inline mscclppResult_t mscclppCudaHostFree(void* ptr) { CUDACHECK(cudaFreeHost(ptr)); return mscclppSuccess; } -template mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) -{ +template +mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) { void* p = malloc(nelem * sizeof(T)); if (p == NULL) { WARN("Failed to malloc %ld bytes", nelem * sizeof(T)); @@ -54,12 +52,10 @@ template mscclppResult_t mscclppCallocDebug(T** ptr, size_t nelem, } #define mscclppCalloc(...) mscclppCallocDebug(__VA_ARGS__, __FILE__, __LINE__) -template mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, size_t nelem) -{ - if (nelem < oldNelem) - return mscclppInternalError; - if (nelem == oldNelem) - return mscclppSuccess; +template +mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, size_t nelem) { + if (nelem < oldNelem) return mscclppInternalError; + if (nelem == oldNelem) return mscclppSuccess; T* oldp = *ptr; T* p = (T*)malloc(nelem * sizeof(T)); @@ -76,8 +72,8 @@ template mscclppResult_t mscclppRealloc(T** ptr, size_t oldNelem, s return mscclppSuccess; } -template mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) -{ +template +mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; *ptr = nullptr; @@ -85,15 +81,14 @@ template mscclppResult_t mscclppCudaMallocDebug(T** ptr, size_t nel CUDACHECKGOTO(cudaMalloc(ptr, nelem * sizeof(T)), result, finish); finish: CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (*ptr == nullptr) - WARN("Failed to CUDA malloc %ld bytes", nelem * sizeof(T)); + if (*ptr == nullptr) WARN("Failed to CUDA malloc %ld bytes", nelem * sizeof(T)); INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr); return result; } #define mscclppCudaMalloc(...) mscclppCudaMallocDebug(__VA_ARGS__, __FILE__, __LINE__) -template mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) -{ +template +mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; *ptr = nullptr; @@ -107,16 +102,15 @@ template mscclppResult_t mscclppCudaCallocDebug(T** ptr, size_t nel CUDACHECKGOTO(cudaStreamDestroy(stream), result, finish); finish: CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (*ptr == nullptr) - WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T)); + if (*ptr == nullptr) WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T)); INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr); return result; } #define mscclppCudaCalloc(...) mscclppCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__) template -mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t stream, const char* filefunc, int line) -{ +mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t stream, const char* filefunc, + int line) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; *ptr = nullptr; @@ -125,15 +119,14 @@ mscclppResult_t mscclppCudaCallocAsyncDebug(T** ptr, size_t nelem, cudaStream_t CUDACHECKGOTO(cudaMemsetAsync(*ptr, 0, nelem * sizeof(T), stream), result, finish); finish: CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (*ptr == nullptr) - WARN("Failed to CUDA calloc async %ld bytes", nelem * sizeof(T)); + if (*ptr == nullptr) WARN("Failed to CUDA calloc async %ld bytes", nelem * sizeof(T)); INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr); return result; } #define mscclppCudaCallocAsync(...) mscclppCudaCallocAsyncDebug(__VA_ARGS__, __FILE__, __LINE__) -template mscclppResult_t mscclppCudaMemcpy(T* dst, T* src, size_t nelem) -{ +template +mscclppResult_t mscclppCudaMemcpy(T* dst, T* src, size_t nelem) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); @@ -148,8 +141,8 @@ finish: return result; } -template mscclppResult_t mscclppCudaMemcpyAsync(T* dst, T* src, size_t nelem, cudaStream_t stream) -{ +template +mscclppResult_t mscclppCudaMemcpyAsync(T* dst, T* src, size_t nelem, cudaStream_t stream) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); @@ -159,8 +152,8 @@ finish: return result; } -template mscclppResult_t mscclppCudaFree(T* ptr) -{ +template +mscclppResult_t mscclppCudaFree(T* ptr) { mscclppResult_t result = mscclppSuccess; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); @@ -173,14 +166,12 @@ finish: // Allocate memory to be potentially ibv_reg_mr'd. This needs to be // allocated on separate pages as those pages will be marked DONTFORK // and if they are shared, that could cause a crash in a child process -inline mscclppResult_t mscclppIbMallocDebug(void** ptr, size_t size, const char* filefunc, int line) -{ +inline mscclppResult_t mscclppIbMallocDebug(void** ptr, size_t size, const char* filefunc, int line) { size_t page_size = sysconf(_SC_PAGESIZE); void* p; int size_aligned = ROUNDUP(size, page_size); int ret = posix_memalign(&p, page_size, size_aligned); - if (ret != 0) - return mscclppSystemError; + if (ret != 0) return mscclppSystemError; memset(p, 0, size); *ptr = p; INFO(MSCCLPP_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr); diff --git a/src/include/api.h b/src/include/api.h new file mode 100644 index 00000000..cb2cac81 --- /dev/null +++ b/src/include/api.h @@ -0,0 +1,7 @@ +#ifndef MSCCLPP_API_H_ +#define MSCCLPP_API_H_ + +#define MSCCLPP_API extern "C" __attribute__((visibility("default"))) +#define MSCCLPP_API_CPP __attribute__((visibility("default"))) + +#endif // MSCCLPP_API_H_ diff --git a/src/include/basic_proxy_handler.hpp b/src/include/basic_proxy_handler.hpp new file mode 100644 index 00000000..2d22a309 --- /dev/null +++ b/src/include/basic_proxy_handler.hpp @@ -0,0 +1,14 @@ +#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_ +#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_ + +#include + +#include "communicator.hpp" + +namespace mscclpp { + +ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm); + +} + +#endif \ No newline at end of file diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h deleted file mode 100644 index 95320b07..00000000 --- a/src/include/bootstrap.h +++ /dev/null @@ -1,35 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef MSCCLPP_BOOTSTRAP_H_ -#define MSCCLPP_BOOTSTRAP_H_ - -#include "mscclpp.h" -#include "socket.h" - -#include "comm.h" - -struct mscclppBootstrapHandle -{ - uint64_t magic; - union mscclppSocketAddress addr; -}; -static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), - "Bootstrap handle is too large to fit inside MSCCLPP unique ID"); - -mscclppResult_t bootstrapNetInit(const char* ip_port_pair = NULL); -mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle); -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, - const char* ip_port_pair = NULL); -mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm); -mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size); -mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); -mscclppResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); -mscclppResult_t bootstrapBarrier(void* commState, int* ranks, int rank, int nranks, int tag); -mscclppResult_t bootstrapIntraNodeAllGather(void* commState, int* ranks, int rank, int nranks, void* allData, int size); -mscclppResult_t bootstrapClose(void* commState); -mscclppResult_t bootstrapAbort(void* commState); -#endif diff --git a/src/include/checks.h b/src/include/checks.h index f93945c7..c877cdea 100644 --- a/src/include/checks.h +++ b/src/include/checks.h @@ -7,178 +7,182 @@ #ifndef MSCCLPP_CHECKS_H_ #define MSCCLPP_CHECKS_H_ -#include "debug.h" #include +#include "debug.h" + // Check CUDA RT calls -#define CUDACHECK(cmd) \ - do { \ - cudaError_t err = cmd; \ - if (err != cudaSuccess) { \ - WARN("Cuda failure '%s'", cudaGetErrorString(err)); \ - return mscclppUnhandledCudaError; \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + WARN("Cuda failure '%s'", cudaGetErrorString(err)); \ + return mscclppUnhandledCudaError; \ + } \ } while (false) -#define CUDACHECKGOTO(cmd, res, label) \ - do { \ - cudaError_t err = cmd; \ - if (err != cudaSuccess) { \ - WARN("Cuda failure '%s'", cudaGetErrorString(err)); \ - res = mscclppUnhandledCudaError; \ - goto label; \ - } \ +#define CUDACHECKNORET(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + WARN("Cuda failure '%s'", cudaGetErrorString(err)); \ + return; \ + } \ + } while (false) + +#define CUDACHECKGOTO(cmd, res, label) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + WARN("Cuda failure '%s'", cudaGetErrorString(err)); \ + res = mscclppUnhandledCudaError; \ + goto label; \ + } \ } while (false) // Report failure but clear error and continue -#define CUDACHECKIGNORE(cmd) \ - do { \ - cudaError_t err = cmd; \ - if (err != cudaSuccess) { \ - INFO(MSCCLPP_ALL, "%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \ - (void)cudaGetLastError(); \ - } \ +#define CUDACHECKIGNORE(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + INFO(MSCCLPP_ALL, "%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \ + (void)cudaGetLastError(); \ + } \ } while (false) #include // Check system calls -#define SYSCHECK(call, name) \ - do { \ - int retval; \ - SYSCHECKVAL(call, name, retval); \ +#define SYSCHECK(call, name) \ + do { \ + int retval; \ + SYSCHECKVAL(call, name, retval); \ } while (false) -#define SYSCHECKVAL(call, name, retval) \ - do { \ - SYSCHECKSYNC(call, name, retval); \ - if (retval == -1) { \ - WARN("Call to " name " failed : %s", strerror(errno)); \ - return mscclppSystemError; \ - } \ +#define SYSCHECKVAL(call, name, retval) \ + do { \ + SYSCHECKSYNC(call, name, retval); \ + if (retval == -1) { \ + WARN("Call to " name " failed : %s", strerror(errno)); \ + return mscclppSystemError; \ + } \ } while (false) -#define SYSCHECKSYNC(call, name, retval) \ - do { \ - retval = call; \ - if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ - INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \ - } else { \ - break; \ - } \ +#define SYSCHECKSYNC(call, name, retval) \ + do { \ + retval = call; \ + if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ + INFO(MSCCLPP_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \ + } else { \ + break; \ + } \ } while (true) -#define SYSCHECKGOTO(statement, res, label) \ - do { \ - if ((statement) == -1) { \ - /* Print the back trace*/ \ - res = mscclppSystemError; \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - goto label; \ - } \ +#define SYSCHECKGOTO(statement, res, label) \ + do { \ + if ((statement) == -1) { \ + /* Print the back trace*/ \ + res = mscclppSystemError; \ + INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + goto label; \ + } \ } while (0); -#define NEQCHECK(statement, value) \ - do { \ - if ((statement) != value) { \ - /* Print the back trace*/ \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \ - return mscclppSystemError; \ - } \ +#define NEQCHECK(statement, value) \ + do { \ + if ((statement) != value) { \ + /* Print the back trace*/ \ + INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \ + return mscclppSystemError; \ + } \ } while (0); -#define NEQCHECKGOTO(statement, value, res, label) \ - do { \ - if ((statement) != value) { \ - /* Print the back trace*/ \ - res = mscclppSystemError; \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - goto label; \ - } \ +#define NEQCHECKGOTO(statement, value, res, label) \ + do { \ + if ((statement) != value) { \ + /* Print the back trace*/ \ + res = mscclppSystemError; \ + INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + goto label; \ + } \ } while (0); -#define EQCHECK(statement, value) \ - do { \ - if ((statement) == value) { \ - /* Print the back trace*/ \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \ - return mscclppSystemError; \ - } \ +#define EQCHECK(statement, value) \ + do { \ + if ((statement) == value) { \ + /* Print the back trace*/ \ + INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, mscclppSystemError); \ + return mscclppSystemError; \ + } \ } while (0); -#define EQCHECKGOTO(statement, value, res, label) \ - do { \ - if ((statement) == value) { \ - /* Print the back trace*/ \ - res = mscclppSystemError; \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - goto label; \ - } \ +#define EQCHECKGOTO(statement, value, res, label) \ + do { \ + if ((statement) == value) { \ + /* Print the back trace*/ \ + res = mscclppSystemError; \ + INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + goto label; \ + } \ } while (0); // Propagate errors up -#define MSCCLPPCHECK(call) \ - do { \ - mscclppResult_t res = call; \ - if (res != mscclppSuccess && res != mscclppInProgress) { \ - /* Print the back trace*/ \ - if (mscclppDebugNoWarn == 0) \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - return res; \ - } \ +#define MSCCLPPCHECK(call) \ + do { \ + mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + return res; \ + } \ } while (0); -#define MSCCLPPCHECKGOTO(call, res, label) \ - do { \ - res = call; \ - if (res != mscclppSuccess && res != mscclppInProgress) { \ - /* Print the back trace*/ \ - if (mscclppDebugNoWarn == 0) \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - goto label; \ - } \ +#define MSCCLPPCHECKGOTO(call, res, label) \ + do { \ + res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + goto label; \ + } \ } while (0); -#define MSCCLPPWAIT(call, cond, abortFlagPtr) \ - do { \ - volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \ - mscclppResult_t res = call; \ - if (res != mscclppSuccess && res != mscclppInProgress) { \ - if (mscclppDebugNoWarn == 0) \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - return mscclppInternalError; \ - } \ - if (tmpAbortFlag) \ - NEQCHECK(*tmpAbortFlag, 0); \ +#define MSCCLPPWAIT(call, cond, abortFlagPtr) \ + do { \ + volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \ + mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + return mscclppInternalError; \ + } \ + if (tmpAbortFlag) NEQCHECK(*tmpAbortFlag, 0); \ } while (!(cond)); -#define MSCCLPPWAITGOTO(call, cond, abortFlagPtr, res, label) \ - do { \ - volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \ - res = call; \ - if (res != mscclppSuccess && res != mscclppInProgress) { \ - if (mscclppDebugNoWarn == 0) \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - goto label; \ - } \ - if (tmpAbortFlag) \ - NEQCHECKGOTO(*tmpAbortFlag, 0, res, label); \ +#define MSCCLPPWAITGOTO(call, cond, abortFlagPtr, res, label) \ + do { \ + volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \ + res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + if (mscclppDebugNoWarn == 0) INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ + goto label; \ + } \ + if (tmpAbortFlag) NEQCHECKGOTO(*tmpAbortFlag, 0, res, label); \ } while (!(cond)); -#define MSCCLPPCHECKTHREAD(a, args) \ - do { \ - if (((args)->ret = (a)) != mscclppSuccess && (args)->ret != mscclppInProgress) { \ - INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, (args)->ret); \ - return args; \ - } \ +#define MSCCLPPCHECKTHREAD(a, args) \ + do { \ + if (((args)->ret = (a)) != mscclppSuccess && (args)->ret != mscclppInProgress) { \ + INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, (args)->ret); \ + return args; \ + } \ } while (0) -#define CUDACHECKTHREAD(a) \ - do { \ - if ((a) != cudaSuccess) { \ - INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, args->ret); \ - args->ret = mscclppUnhandledCudaError; \ - return args; \ - } \ +#define CUDACHECKTHREAD(a) \ + do { \ + if ((a) != cudaSuccess) { \ + INFO(MSCCLPP_INIT, "%s:%d -> %d [Async thread]", __FILE__, __LINE__, args->ret); \ + args->ret = mscclppUnhandledCudaError; \ + return args; \ + } \ } while (0) #endif diff --git a/src/include/checks.hpp b/src/include/checks.hpp new file mode 100644 index 00000000..00acc2f3 --- /dev/null +++ b/src/include/checks.hpp @@ -0,0 +1,44 @@ +#ifndef MSCCLPP_CHECKS_HPP_ +#define MSCCLPP_CHECKS_HPP_ + +#include +#include + +#include + +#include "debug.h" + +#define MSCCLPPTHROW(call) \ + do { \ + mscclppResult_t res = call; \ + mscclpp::ErrorCode err = mscclpp::ErrorCode::InternalError; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + if (res == mscclppInvalidUsage) { \ + err = mscclpp::ErrorCode::InvalidUsage; \ + } else if (res == mscclppSystemError) { \ + err = mscclpp::ErrorCode::SystemError; \ + } \ + throw mscclpp::Error(std::string("Call to " #call " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \ + err); \ + } \ + } while (false) + +#define CUDATHROW(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + throw mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \ + err); \ + } \ + } while (false) + +#define CUTHROW(cmd) \ + do { \ + CUresult err = cmd; \ + if (err != CUDA_SUCCESS) { \ + throw mscclpp::CuError(std::string("Call to " #cmd " failed.") + __FILE__ + ":" + std::to_string(__LINE__), \ + err); \ + } \ + } while (false) + +#endif diff --git a/src/include/comm.h b/src/include/comm.h deleted file mode 100644 index b45f4348..00000000 --- a/src/include/comm.h +++ /dev/null @@ -1,72 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef MSCCLPP_COMM_H_ -#define MSCCLPP_COMM_H_ - -#include "ib.h" -#include "proxy.h" - -#include - -// #define CACHE_LINE_SIZE 128 -// #define MEM_ALIGN 4096 -// #define CUDA_IPC_MIN 2097152UL - -// // Channels / LL tuning -// #define MSCCLPP_LL_THREAD_THRESHOLD 8 -// #define MSCCLPP_LL128_THREAD_THRESHOLD 8 -// #define MSCCLPP_SIMPLE_THREAD_THRESHOLD 64 - -#define MAXCONNECTIONS 64 - -struct mscclppConn -{ - mscclppTransport_t transport; - int remoteRank; - uint64_t buffSize; - uint64_t* remoteProxyFlag; - uint64_t* cpuProxyFlag; - void* cpuProxyFlagGdrDesc; - struct mscclppDevConn* devConn; - struct mscclppIbContext* ibCtx; - struct mscclppIbQp* ibQp; - struct mscclppIbMr* ibBuffMr; - struct mscclppIbMr* ibLocalFlagMr; - struct mscclppIbMr* ibProxyFlagMr; - struct mscclppIbMrInfo ibBuffMrInfo; - struct mscclppIbMrInfo ibLocalFlagMrInfo; - struct mscclppIbMrInfo ibProxyFlagMrInfo; -#if defined(ENABLE_NPKIT) - std::vector npkitUsedReqIds; - std::vector npkitFreeReqIds; -#endif -}; - -struct mscclppComm -{ - struct mscclppConn conns[MAXCONNECTIONS]; - struct mscclppDevConn devConns[MAXCONNECTIONS]; - int nConns; - - void* bootstrap; - - uint64_t - magic; // Magic number for all network communication. Not a security key -- only goal is to detect mismatches. - - int rank; // my rank in the communicator - int nRanks; // number of GPUs in communicator - int cudaDev; // my cuda device index - int devNumaNode; // my device's NUMA node - - // Flag to ask MSCCLPP kernels to abort - volatile uint32_t* abortFlag; - - struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS]; - struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM]; -}; - -#endif diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp new file mode 100644 index 00000000..6461eb13 --- /dev/null +++ b/src/include/communicator.hpp @@ -0,0 +1,36 @@ +#ifndef MSCCL_COMMUNICATOR_HPP_ +#define MSCCL_COMMUNICATOR_HPP_ + +#include + +#include +#include +#include +#include + +#include "ib.hpp" +#include "mscclpp.h" + +namespace mscclpp { + +class ConnectionBase; + +struct Communicator::Impl { + std::vector> connections_; + std::vector> toSetup_; + std::unordered_map> ibContexts_; + cudaStream_t ipcStream_; + std::shared_ptr bootstrap_; + std::vector rankToHash_; + + Impl(std::shared_ptr bootstrap); + + ~Impl(); + + IbCtx* getIbContext(Transport ibTransport); + cudaStream_t getIpcStream(); +}; + +} // namespace mscclpp + +#endif // MSCCL_COMMUNICATOR_HPP_ diff --git a/src/include/config.h b/src/include/config.h index 49f8cb57..60fe3e3e 100644 --- a/src/include/config.h +++ b/src/include/config.h @@ -3,16 +3,15 @@ #include -class mscclppConfig -{ -public: +class mscclppConfig { + public: time_t bootstrapConnectionTimeout = 30; static mscclppConfig* getInstance(); time_t getBootstrapConnectionTimeoutConfig(); void setBootstrapConnectionTimeoutConfig(time_t timeout); -private: + private: mscclppConfig() = default; mscclppConfig(const mscclppConfig&) = delete; mscclppConfig& operator=(const mscclppConfig&) = delete; @@ -20,4 +19,4 @@ private: static mscclppConfig _instance; }; -#endif // end include guard +#endif // end include guard diff --git a/src/include/connection.hpp b/src/include/connection.hpp new file mode 100644 index 00000000..3e9896ba --- /dev/null +++ b/src/include/connection.hpp @@ -0,0 +1,72 @@ +#ifndef MSCCLPP_CONNECTION_HPP_ +#define MSCCLPP_CONNECTION_HPP_ + +// TODO(saemal): make this configurable +#define MSCCLPP_POLLING_WAIT 3e7 // in microseconds + +#include + +#include + +#include "communicator.hpp" +#include "ib.hpp" + +namespace mscclpp { + +// TODO: Add functionality to these classes for Communicator to do connectionSetup + +class ConnectionBase : public Connection, public Setuppable { + int remoteRank_; + int tag_; + + public: + ConnectionBase(int remoteRank, int tag); + + int remoteRank() override; + int tag() override; +}; + +class CudaIpcConnection : public ConnectionBase { + cudaStream_t stream_; + + public: + CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream); + + ~CudaIpcConnection(); + + Transport transport() override; + + Transport remoteTransport() override; + + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; + + void flush() override; +}; + +class IBConnection : public ConnectionBase { + Transport transport_; + Transport remoteTransport_; + IbQp* qp; + int numSignaledSends; + + public: + IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); + + Transport transport() override; + + Transport remoteTransport() override; + + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; + + void flush() override; + + void beginSetup(std::shared_ptr bootstrap) override; + + void endSetup(std::shared_ptr bootstrap) override; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_CONNECTION_HPP_ diff --git a/src/include/core.h b/src/include/core.h deleted file mode 100644 index e3213bd6..00000000 --- a/src/include/core.h +++ /dev/null @@ -1,30 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef MSCCLPP_CORE_H_ -#define MSCCLPP_CORE_H_ - -#include "alloc.h" -#include "debug.h" -#include "mscclpp.h" -#include "param.h" -#include // For std::min/std::max -#include -#include -#include -#include -#include -#include - -#ifdef PROFAPI -#define MSCCLPP_API(ret, func, args...) \ - __attribute__((visibility("default"))) __attribute__((alias(#func))) ret p##func(args); \ - extern "C" __attribute__((visibility("default"))) __attribute__((weak)) ret func(args) -#else -#define MSCCLPP_API(ret, func, args...) extern "C" __attribute__((visibility("default"))) ret func(args) -#endif // end PROFAPI - -#endif // end include guard diff --git a/src/include/debug.h b/src/include/debug.h index dd548cbb..64b37297 100644 --- a/src/include/debug.h +++ b/src/include/debug.h @@ -7,20 +7,20 @@ #ifndef MSCCLPP_DEBUG_H_ #define MSCCLPP_DEBUG_H_ -#include "mscclpp.h" -#include -#include -#include - #include #include +#include #include +#include +#include + +#include "mscclpp.h" + // Conform to pthread and NVTX standard #define MSCCLPP_THREAD_NAMELEN 16 -typedef enum -{ +typedef enum { MSCCLPP_LOG_NONE = 0, MSCCLPP_LOG_VERSION = 1, MSCCLPP_LOG_WARN = 2, @@ -28,8 +28,7 @@ typedef enum MSCCLPP_LOG_ABORT = 4, MSCCLPP_LOG_TRACE = 5 } mscclppDebugLogLevel; -typedef enum -{ +typedef enum { MSCCLPP_INIT = 1, MSCCLPP_COLL = 2, MSCCLPP_P2P = 4, diff --git a/src/include/gdr.h b/src/include/gdr.h deleted file mode 100644 index d7e0269a..00000000 --- a/src/include/gdr.h +++ /dev/null @@ -1,156 +0,0 @@ -#ifndef MSCCLPP_GDR_H_ -#define MSCCLPP_GDR_H_ - -#include "align.h" -#include "alloc.h" -#include "checks.h" -#include "debug.h" -#include "gdrapi.h" - -// These can be used if the GDR library isn't thread safe -#include -extern pthread_mutex_t gdrLock; -#define GDRLOCK() pthread_mutex_lock(&gdrLock) -#define GDRUNLOCK() pthread_mutex_unlock(&gdrLock) -#define GDRLOCKCALL(cmd, ret) \ - do { \ - GDRLOCK(); \ - ret = cmd; \ - GDRUNLOCK(); \ - } while (false) - -#define GDRCHECK(cmd) \ - do { \ - int e; \ - /* GDRLOCKCALL(cmd, e); */ \ - e = cmd; \ - if (e != 0) { \ - WARN("GDRCOPY failure %d", e); \ - return mscclppSystemError; \ - } \ - } while (false) - -gdr_t wrap_gdr_open(void); -mscclppResult_t wrap_gdr_close(gdr_t g); -mscclppResult_t wrap_gdr_pin_buffer(gdr_t g, unsigned long addr, size_t size, uint64_t p2p_token, uint32_t va_space, - gdr_mh_t* handle); -mscclppResult_t wrap_gdr_unpin_buffer(gdr_t g, gdr_mh_t handle); -mscclppResult_t wrap_gdr_get_info(gdr_t g, gdr_mh_t handle, gdr_info_t* info); -mscclppResult_t wrap_gdr_map(gdr_t g, gdr_mh_t handle, void** va, size_t size); -mscclppResult_t wrap_gdr_unmap(gdr_t g, gdr_mh_t handle, void* va, size_t size); - -// Global GDR driver handle -extern gdr_t mscclppGdrCopy; - -typedef struct gdr_mem_desc -{ - void* gdrDevMem; - void* gdrMap; - size_t gdrOffset; - size_t gdrMapSize; - gdr_mh_t gdrMh; -} gdr_mem_desc_t; - -static gdr_t mscclppGdrInit() -{ - // int libMajor, libMinor, drvMajor, drvMinor; - gdr_t handle = wrap_gdr_open(); - - // if (handle != NULL) { - // mscclppResult_t res; - - // // Query the version of libgdrapi - // MSCCLPPCHECKGOTO(wrap_gdr_runtime_get_version(&libMajor, &libMinor), res, error); - - // // Query the version of gdrdrv driver - // MSCCLPPCHECKGOTO(wrap_gdr_driver_get_version(handle, &drvMajor, &drvMinor), res, error); - - // // Only support GDRAPI 2.1 and later - // if (libMajor < 2 || (libMajor == 2 && libMinor < 1) || drvMajor < 2 || (drvMajor == 2 && drvMinor < 1)) { - // goto error; - // } - // else - // INFO(MSCCLPP_INIT, "GDRCOPY enabled library %d.%d driver %d.%d", libMajor, libMinor, drvMajor, drvMinor); - // } - return handle; - // error: - // if (handle != NULL) (void) wrap_gdr_close(handle); - // return NULL; -} - -template -mscclppResult_t mscclppGdrCudaCallocDebug(T** ptr, T** devPtr, size_t nelem, void** gdrDesc, const char* filefunc, - int line) -{ - mscclppResult_t result = mscclppSuccess; - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - *ptr = nullptr; - *devPtr = nullptr; - *gdrDesc = nullptr; - CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - - gdr_info_t info; - size_t mapSize; - gdr_mh_t mh; - char* devMem; - void* gdrMap; - ssize_t off; - gdr_mem_desc_t* md; - uint64_t alignedAddr; - size_t align; - - mapSize = sizeof(T) * nelem; - - // GDRCOPY Pinned buffer has to be a minimum of a GPU_PAGE_SIZE - ALIGN_SIZE(mapSize, GPU_PAGE_SIZE); - // GDRCOPY Pinned buffer has to be GPU_PAGE_SIZE aligned too - MSCCLPPCHECKGOTO(mscclppCudaCalloc(&devMem, mapSize + GPU_PAGE_SIZE - 1), result, finish); - alignedAddr = (((uint64_t)devMem) + GPU_PAGE_OFFSET) & GPU_PAGE_MASK; - align = alignedAddr - (uint64_t)devMem; - MSCCLPPCHECKGOTO(wrap_gdr_pin_buffer(mscclppGdrCopy, alignedAddr, mapSize, 0, 0, &mh), result, finish); - - MSCCLPPCHECKGOTO(wrap_gdr_map(mscclppGdrCopy, mh, &gdrMap, mapSize), result, finish); - - MSCCLPPCHECKGOTO(wrap_gdr_get_info(mscclppGdrCopy, mh, &info), result, finish); - - // Will offset ever be non zero ? - off = info.va - alignedAddr; - - MSCCLPPCHECKGOTO(mscclppCalloc(&md, 1), result, finish); - md->gdrDevMem = devMem; - md->gdrMap = gdrMap; - md->gdrMapSize = mapSize; - md->gdrOffset = off + align; - md->gdrMh = mh; - *gdrDesc = md; - - *ptr = (T*)((char*)gdrMap + off); - if (devPtr) - *devPtr = (T*)(devMem + off + align); - - TRACE(mscclpp_INIT, "GDRCOPY : allocated devMem %p gdrMap %p offset %lx mh %lx mapSize %zi at %p", md->gdrDevMem, - md->gdrMap, md->gdrOffset, md->gdrMh.h, md->gdrMapSize, *ptr); - - return mscclppSuccess; - -finish: - CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (*ptr == nullptr) - WARN("Failed to CUDA calloc %ld bytes", nelem * sizeof(T)); - INFO(MSCCLPP_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr); - return result; -} -#define mscclppGdrCudaCalloc(...) mscclppGdrCudaCallocDebug(__VA_ARGS__, __FILE__, __LINE__) - -static mscclppResult_t mscclppGdrCudaFree(void* gdrDesc) -{ - gdr_mem_desc_t* md = (gdr_mem_desc_t*)gdrDesc; - MSCCLPPCHECK(wrap_gdr_unmap(mscclppGdrCopy, md->gdrMh, md->gdrMap, md->gdrMapSize)); - MSCCLPPCHECK(wrap_gdr_unpin_buffer(mscclppGdrCopy, md->gdrMh)); - CUDACHECK(cudaFree(md->gdrDevMem)); - free(md); - - return mscclppSuccess; -} - -#endif diff --git a/src/include/ib.h b/src/include/ib.h deleted file mode 100644 index 1d059b7c..00000000 --- a/src/include/ib.h +++ /dev/null @@ -1,86 +0,0 @@ -#ifndef MSCCLPP_IB_H_ -#define MSCCLPP_IB_H_ - -#include "mscclpp.h" -#include -#include -#include -#include - -#define MSCCLPP_IB_CQ_SIZE 1024 -#define MSCCLPP_IB_CQ_POLL_NUM 4 -#define MSCCLPP_IB_MAX_SENDS 64 -#define MSCCLPP_IB_MAX_DEVS 8 - -// MR info to be shared with the remote peer -struct mscclppIbMrInfo -{ - uint64_t addr; - uint32_t rkey; -}; - -// IB memory region -struct mscclppIbMr -{ - struct ibv_mr* mr; - void* buff; - struct mscclppIbMrInfo info; -}; - -// QP info to be shared with the remote peer -struct mscclppIbQpInfo -{ - uint16_t lid; - uint8_t port; - uint8_t linkLayer; - uint32_t qpn; - uint64_t spn; - ibv_mtu mtu; - uint64_t iid; - bool is_grh; -}; - -// IB queue pair -struct mscclppIbQp -{ - struct ibv_qp* qp; - struct mscclppIbQpInfo info; - struct ibv_send_wr* wrs; - struct ibv_sge* sges; - struct ibv_cq* cq; - struct ibv_wc* wcs; - int wrn; - - int rtr(const mscclppIbQpInfo* info); - int rts(); - int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled); - int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); - int postSend(); - int postRecv(uint64_t wrId); - int pollCq(); -}; - -// Holds resources of a single IB device. -struct mscclppIbContext -{ - struct ibv_context* ctx; - struct ibv_pd* pd; - int* ports; - int nPorts; - struct mscclppIbQp* qps; - int nQps; - int maxQps; - struct mscclppIbMr* mrs; - int nMrs; - int maxMrs; -}; - -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName); -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx); -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1); -mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr); - -#endif diff --git a/src/include/ib.hpp b/src/include/ib.hpp new file mode 100644 index 00000000..2fe9a447 --- /dev/null +++ b/src/include/ib.hpp @@ -0,0 +1,105 @@ +#ifndef MSCCLPP_IB_HPP_ +#define MSCCLPP_IB_HPP_ + +#include +#include +#include + +#define MSCCLPP_IB_CQ_SIZE 1024 +#define MSCCLPP_IB_CQ_POLL_NUM 1 +#define MSCCLPP_IB_MAX_SENDS 64 +#define MSCCLPP_IB_MAX_DEVS 8 + +namespace mscclpp { + +struct IbMrInfo { + uint64_t addr; + uint32_t rkey; +}; + +class IbMr { + public: + ~IbMr(); + + IbMrInfo getInfo() const; + const void* getBuff() const; + uint32_t getLkey() const; + + private: + IbMr(void* pd, void* buff, std::size_t size); + + void* mr; + void* buff; + std::size_t size; + + friend class IbCtx; +}; + +// QP info to be shared with the remote peer +struct IbQpInfo { + uint16_t lid; + uint8_t port; + uint8_t linkLayer; + uint32_t qpn; + uint64_t spn; + int mtu; + uint64_t iid; + bool is_grh; +}; + +class IbQp { + public: + ~IbQp(); + + void rtr(const IbQpInfo& info); + void rts(); + int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled); + int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled, unsigned int immData); + void postSend(); + void postRecv(uint64_t wrId); + int pollCq(); + + IbQpInfo& getInfo(); + const void* getWc(int idx) const; + + private: + IbQp(void* ctx, void* pd, int port); + + IbQpInfo info; + + void* qp; + void* cq; + void* wcs; + void* wrs; + void* sges; + int wrn; + + friend class IbCtx; +}; + +class IbCtx { + public: + IbCtx(const std::string& devName); + ~IbCtx(); + + IbQp* createQp(int port = -1); + const IbMr* registerMr(void* buff, std::size_t size); + + const std::string& getDevName() const; + + private: + bool isPortUsable(int port) const; + int getAnyActivePort() const; + + const std::string devName; + void* ctx; + void* pd; + std::list> qps; + std::list> mrs; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_IB_HPP_ diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 48544911..b57dc263 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -12,12 +12,25 @@ #define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 #include -#include + +#include +// #includa #ifdef __cplusplus extern "C" { #endif +struct alignas(16) mscclppDevConnSignalEpochId { + // every signal(), increaments this and either: + // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy + // 2) gpu thread directly writes it to remoteSignalEpochId->device + uint64_t device; + // signal() function triggers the cpu proxy thread to write to it + uint64_t proxy; +}; + +using mscclppBufferHandle_t = uint32_t; + /*************************************************************************************************************** * A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand. * The communication API is one-sided meaning that for every single data transfer, only one side @@ -80,39 +93,30 @@ extern "C" { * The two endpoint can concurrently use the same connection provided they are writing (puts) on different * indices in the registered buffer. **************************************************************************************************************/ -struct mscclppDevConn -{ +struct mscclppDevConn { #ifdef __CUDACC__ - __forceinline__ __device__ void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) - { + __forceinline__ __device__ void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) { fifo.push(mscclppData, dstDataOffset, srcDataOffset, dataSize); } - __forceinline__ __device__ void put(uint64_t dataOffset, uint64_t dataSize) - { - put(dataOffset, dataOffset, dataSize); - } + __forceinline__ __device__ void put(uint64_t dataOffset, uint64_t dataSize) { put(dataOffset, dataOffset, dataSize); } - __forceinline__ __device__ void signal() - { + __forceinline__ __device__ void signal() { epochIncrement(); fifo.push(mscclppFlag, 0, 0, 1); } - __forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) - { + __forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) { epochIncrement(); fifo.push(mscclppData | mscclppFlag, dstDataOffset, srcDataOffset, dataSize); } - __forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize) - { + __forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize) { putWithSignal(dataOffset, dataOffset, dataSize); } __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstDataOffset, uint64_t srcDataOffset, - uint64_t dataSize) - { + uint64_t dataSize) { epochIncrement(); uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize); while (*(volatile uint64_t*)&fifo.triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && @@ -120,13 +124,11 @@ struct mscclppDevConn ; } - __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize) - { + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dataOffset, uint64_t dataSize) { putWithSignalAndFlush(dataOffset, dataOffset, dataSize); } - __forceinline__ __device__ void flush() - { + __forceinline__ __device__ void flush() { uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1); // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. @@ -137,84 +139,101 @@ struct mscclppDevConn // Version that uses the SM directly to do the copy, instead of using the proxy thread like the functions above. __forceinline__ __device__ void putDirect(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize, - uint32_t threadId, uint32_t numThreads) - { + uint32_t threadId, uint32_t numThreads) { uint64_t* src = (uint64_t*)((char*)localBuff + srcDataOffset); uint64_t* dst = (uint64_t*)((char*)remoteBuff + dstDataOffset); // assume the memory is aligned to 8 bytes size_t nElem = - dataSize % sizeof(uint64_t) ? (dataSize + sizeof(uint64_t)) / sizeof(uint64_t) : dataSize / sizeof(uint64_t); + dataSize % sizeof(uint64_t) ? (dataSize + sizeof(uint64_t)) / sizeof(uint64_t) : dataSize / sizeof(uint64_t); for (size_t i = threadId; i < nElem; i += numThreads) { dst[i] = src[i]; } } __forceinline__ __device__ void putDirect(uint64_t dataOffset, uint64_t dataSize, uint32_t threadId, - uint32_t numThreads) - { + uint32_t numThreads) { putDirect(dataOffset, dataOffset, dataSize, threadId, numThreads); } - __forceinline__ __device__ void signalDirect() - { + __forceinline__ __device__ void signalDirect() { // This fence ensures that the writes from a preceding putDirect() are visible on the peer GPU before the // incremented epoch id is visible. __threadfence_system(); epochIncrement(); - *(volatile uint64_t*)remoteEpochId = *sendEpochId; + *(volatile uint64_t*)&(remoteSignalEpochId->device) = localSignalEpochId->device; } - __forceinline__ __device__ void wait() - { - (*recvEpochId) += 1; - // printf("%llu %llu %llu\n", *(volatile uint64_t*)proxyEpochId, (*recvEpochId), *(volatile uint64_t*)sendEpochId); - while (*(volatile uint64_t*)proxyEpochId < (*recvEpochId)) + __forceinline__ __device__ void wait() { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) ; } - __forceinline__ __device__ void waitDirect() - { - (*recvEpochId) += 1; - while (*(volatile uint64_t*)directRecvEpochId < (*recvEpochId)) + __forceinline__ __device__ void waitDirect() { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->device) < (*waitEpochId)) ; } - __forceinline__ __device__ void epochIncrement() - { - *(volatile uint64_t*)sendEpochId += 1; - } + __forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(localSignalEpochId->device) += 1; } -#endif - int remoteRank; - int tag; - - void* localBuff; - uint64_t* sendEpochId; // this is read and written by the GPU - uint64_t* recvEpochId; // this is the expected recv epoch id. - uint64_t* directRecvEpochId; // this is read and written by remote GPU. - - void* remoteBuff; - uint64_t* remoteFlag; - uint64_t* remoteEpochId; - uint64_t* proxyEpochId; // this is only written by the proxy thread +#endif // __CUDACC__ // this is a concurrent fifo which is multiple threads from the device // can produce for and the sole proxy thread consumes it. struct mscclppConcurrentFifo fifo; + + int remoteRank; + int tag; + + // my local buffer + void* localBuff; + + struct mscclppDevConnSignalEpochId* localSignalEpochId; + // used by the signal() function directly from gpu + struct mscclppDevConnSignalEpochId* remoteSignalEpochId; + + // every wait(), increaments this and then the gpu waits for either: + // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread + // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread + uint64_t* waitEpochId; + + // my remote peer's buffer. only non-NULL with gpu's direct access + // gpu can directly write into it + void* remoteBuff; +}; + +// Host interface for mscclppDevCon functionality +struct mscclppHostConn { + virtual ~mscclppHostConn() = default; + virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0; + virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, + uint64_t dataSize) = 0; + virtual void signal() = 0; + virtual void wait() = 0; + virtual void flush() = 0; }; typedef struct mscclppComm* mscclppComm_t; typedef struct mscclppDevConn mscclppDevConn_t; +typedef struct mscclppHostConn mscclppHostConn_t; #define MSCCLPP_UNIQUE_ID_BYTES 128 -typedef struct -{ +typedef struct { char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId; +struct mscclppRegisteredMemoryP2P { + void* remoteBuff; + const void* IbMr; +}; + +struct mscclppRegisteredMemory { + std::vector p2p; +}; + /* Error type */ -typedef enum -{ +typedef enum { mscclppSuccess = 0, mscclppUnhandledCudaError = 1, mscclppSystemError = 2, @@ -236,10 +255,9 @@ typedef enum mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* uniqueId); /* Transport Types */ -typedef enum -{ +typedef enum { mscclppTransportP2P = 0, - mscclppTransportSHM = 1, // TODO(chhwang): not implemented yet + mscclppTransportSHM = 1, // TODO(chhwang): not implemented yet mscclppTransportIB = 2, } mscclppTransport_t; @@ -319,6 +337,40 @@ const char* mscclppGetErrorString(mscclppResult_t result); mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev = 0); +/* Connect to a remote rank. This function only prepares metadata for connection. The actual connection + * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection + * from rank i to remote rank j needs to have a counterpart from rank j to rank i. + * Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages + * and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has + * security risks if the devConn's accesses are given to a malicious process. + * + * This version does not register a buffer. Buffers should instead be registered with mscclppRegisterBuffer(). + * + * Inputs: + * comm: the communicator + * remoteRank: the rank of the remote process + * tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be + * used to identify the connection inside a GPU kernel. + * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) + * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. + */ +mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, + mscclppTransport_t transportType, const char* ibDev = 0); + +/* Register a buffer for use with a connection. + * + * Inputs: + * comm: the communicator + * connIdx: the index of the connection by order of calls to mscclppConnect/mscclppConnectWithoutBuffer + * localBuff: the local send/receive buffer + * buffSize: the size of the local buffer + * + * Outputs: + * handle: a handle to the buffer registration + */ +mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, + mscclppBufferHandle_t* handle); + /* Establish all connections declared by mscclppConnect(). This function must be called after all mscclppConnect() * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. * @@ -415,8 +467,35 @@ void mscclppDefaultLogHandler(const char* msg); */ mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler); +/* Register a buffer for RDMA. + * + * Outputs: + * regMem: the registered memory + * + * Inputs: + * comm: the communicator + * local_memory: the local buffer to be registered + * size: the size of the buffer + */ +mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size, + mscclppRegisteredMemory* regMem); + +/* Write to a registered buffer. + * + * Inputs: + * comm: the communicator + * regMem: the registered memory + * srcBuff: the source buffer + * size: the size of the buffer + * srcOffset: the offset of the source buffer + * dstOffset: the offset of the destination buffer + * stream: the CUDA stream + */ +mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff, + size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream); + #ifdef __cplusplus -} // end extern "C" +} // end extern "C" #endif -#endif // MSCCLPP_H_ +#endif // MSCCLPP_H_ diff --git a/src/include/mscclppfifo.h b/src/include/mscclppfifo.h index 78918fff..030220dd 100644 --- a/src/include/mscclppfifo.h +++ b/src/include/mscclppfifo.h @@ -7,12 +7,7 @@ extern "C" { #endif -typedef enum : uint64_t -{ - mscclppData = 0x1, - mscclppFlag = 0x2, - mscclppSync = 0x4 -} mscclppTriggerType_t; +typedef enum : uint64_t { mscclppData = 0x1, mscclppFlag = 0x2, mscclppSync = 0x4 } mscclppTriggerType_t; #define MSCCLPP_BITS_SIZE 32 #define MSCCLPP_BITS_OFFSET 32 @@ -23,17 +18,16 @@ typedef enum : uint64_t // the summation of number of bits must be 128 or less union alignas(16) mscclppTrigger { uint64_t value[2]; - struct - { + struct { // first 64 bits: value[0] uint64_t dataSize : MSCCLPP_BITS_SIZE; uint64_t srcDataOffset : MSCCLPP_BITS_OFFSET; - uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment // second 64 bits: value[1] uint64_t dstDataOffset : MSCCLPP_BITS_OFFSET; uint64_t connId : MSCCLPP_BITS_CONNID; uint64_t type : MSCCLPP_BITS_TYPE; - uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_CONNID - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_CONNID - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment } fields; }; @@ -49,16 +43,14 @@ typedef mscclppTrigger* mscclppTrigger_t; * push() function increments triggerFifoHead, proxyState->fifoTailHost is updated in proxy.cc:mscclppProxyService * and it occasionally flushes it to triggerFifoTail via a cudaMemcpyAsync. * - * Why douplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates + * Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates * for the tail as there is usually enough space for device threads to push their work into. */ -struct mscclppConcurrentFifo -{ +struct mscclppConcurrentFifo { #ifdef __CUDACC__ __forceinline__ __device__ uint64_t push(uint64_t type, uint64_t dstDataOffset, uint64_t srcDataOffset, - uint64_t dataSize) - { + uint64_t dataSize) { uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead, 1); while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail)) ; @@ -71,16 +63,16 @@ struct mscclppConcurrentFifo return curFifoHead; } -#endif // __CUDACC__ - mscclppTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements - uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused - // occasionally to device - uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device +#endif // __CUDACC__ + mscclppTrigger* triggerFifo; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements + uint64_t* triggerFifoTail; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused + // occasionally to device + uint64_t* triggerFifoHead; // Allocated on device. Only accessed by device int connId; }; #ifdef __cplusplus -} // end extern "C" +} // end extern "C" #endif -#endif // MSCCLPPFIFO_H_ +#endif // MSCCLPPFIFO_H_ diff --git a/src/include/param.h b/src/include/param.h deleted file mode 100644 index e7478807..00000000 --- a/src/include/param.h +++ /dev/null @@ -1,30 +0,0 @@ -/************************************************************************* - * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef MSCCLPP_PARAM_H_ -#define MSCCLPP_PARAM_H_ - -#include - -const char* userHomeDir(); -void setEnvFile(const char* fileName); -void initEnv(); - -void mscclppLoadParam(char const* env, int64_t deftVal, int64_t uninitialized, int64_t* cache); - -#define MSCCLPP_PARAM(name, env, deftVal) \ - int64_t mscclppParam##name() \ - { \ - constexpr int64_t uninitialized = INT64_MIN; \ - static_assert(deftVal != uninitialized, "default value cannot be the uninitialized value."); \ - static int64_t cache = uninitialized; \ - if (__builtin_expect(__atomic_load_n(&cache, __ATOMIC_RELAXED) == uninitialized, false)) { \ - mscclppLoadParam("MSCCLPP_" env, deftVal, uninitialized, &cache); \ - } \ - return cache; \ - } - -#endif diff --git a/src/include/proxy.h b/src/include/proxy.h index cf496f0f..17e92dfd 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -1,25 +1,28 @@ #ifndef MSCCLPP_PROXY_H_ #define MSCCLPP_PROXY_H_ -#include "comm.h" -#include "mscclpp.h" #include #include -#define MSCCLPP_PROXY_MAX_NUM (MSCCLPP_IB_MAX_DEVS + 1) // One is for a P2P proxy. +#include -typedef enum -{ +#include "comm.h" +#include "mscclpp.h" + +#define MSCCLPP_PROXY_MAX_NUM (MSCCLPP_IB_MAX_DEVS + 1) // One is for a P2P proxy. + +typedef enum { MSCCLPP_PROXY_RUN_STATE_IDLE = 0, MSCCLPP_PROXY_RUN_STATE_RUNNING, MSCCLPP_PROXY_RUN_STATE_EXITING, } mscclppProxyRunState_t; -struct mscclppProxyState -{ - mscclppTransport_t transportType; - pthread_t thread; - mscclppProxyRunState_t run; +struct mscclppProxyFifo { + mscclppResult_t create(); + mscclppResult_t destroy(); + mscclppResult_t poll(mscclppTrigger* trigger); + mscclppResult_t pop(); + mscclppResult_t flushTail(bool sync = false); // fifo cudaHostCalloc'ed that is produced by device and consumed by host mscclppTrigger* triggerFifo; @@ -45,10 +48,20 @@ struct mscclppProxyState // these updates are pushed to the device. uint64_t fifoTailHost; + // for transferring fifo tail + cudaStream_t stream; +}; + +struct mscclppProxyState { + mscclppTransport_t transportType; + pthread_t thread; + mscclppProxyRunState_t run; + int numaNodeToBind; - struct mscclppIbContext* ibContext; // For IB connection only - cudaStream_t p2pStream; // for P2P DMA engine only - cudaStream_t fifoStream; // for transferring fifo tail + mscclpp::IbCtx* ibContext; // For IB connection only + cudaStream_t p2pStream; // for P2P DMA engine only + + struct mscclppProxyFifo fifo; }; mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm); diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp new file mode 100644 index 00000000..be32e25a --- /dev/null +++ b/src/include/registered_memory.hpp @@ -0,0 +1,55 @@ +#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_ +#define MSCCLPP_REGISTERED_MEMORY_HPP_ + +#include + +#include +#include + +#include "communicator.hpp" +#include "ib.hpp" +#include "mscclpp.h" + +namespace mscclpp { + +struct TransportInfo { + Transport transport; + + // TODO: rewrite this using std::variant or something + bool ibLocal; + union { + struct { + cudaIpcMemHandle_t cudaIpcBaseHandle; + size_t cudaIpcOffsetFromBase; + }; + struct { + const IbMr* ibMr; + IbMrInfo ibMrInfo; + }; + }; +}; + +struct RegisteredMemory::Impl { + void* data; + size_t size; + int rank; + uint64_t hostHash; + TransportFlags transports; + std::vector transportInfos; + + Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); + Impl(const std::vector& data); + + TransportInfo& getTransportInfo(Transport transport) { + for (auto& entry : transportInfos) { + if (entry.transport == transport) { + return entry; + } + } + throw Error("Transport data not found", ErrorCode::InternalError); + } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_REGISTERED_MEMORY_HPP_ diff --git a/src/include/socket.h b/src/include/socket.h index 556c6bb8..f17f74a8 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -7,7 +7,6 @@ #ifndef MSCCLPP_SOCKET_H_ #define MSCCLPP_SOCKET_H_ -#include "mscclpp.h" #include #include #include @@ -16,9 +15,11 @@ #include #include +#include "mscclpp.h" + #define MAX_IFS 16 #define MAX_IF_NAME_SIZE 16 -#define SLEEP_INT 1000 // connection retry sleep interval in usec +#define SLEEP_INT 1000 // connection retry sleep interval in usec #define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV) #define MSCCLPP_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL @@ -29,8 +30,7 @@ union mscclppSocketAddress { struct sockaddr_in6 sin6; }; -enum mscclppSocketState -{ +enum mscclppSocketState { mscclppSocketStateNone = 0, mscclppSocketStateInitialized = 1, mscclppSocketStateAccepting = 2, @@ -44,8 +44,7 @@ enum mscclppSocketState mscclppSocketStateNum = 10 }; -enum mscclppSocketType -{ +enum mscclppSocketType { mscclppSocketTypeUnknown = 0, mscclppSocketTypeBootstrap = 1, mscclppSocketTypeProxy = 2, @@ -53,8 +52,7 @@ enum mscclppSocketType mscclppSocketTypeNetIb = 4 }; -struct mscclppSocket -{ +struct mscclppSocket { int fd; int acceptFd; int connectRetries; @@ -75,7 +73,7 @@ int mscclppFindInterfaceMatchSubnet(char* ifNames, union mscclppSocketAddress* l int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs); // Initialize a socket -mscclppResult_t mscclppSocketInit(struct mscclppSocket* sock, union mscclppSocketAddress* addr = NULL, +mscclppResult_t mscclppSocketInit(struct mscclppSocket* sock, const mscclppSocketAddress* addr = NULL, uint64_t magic = MSCCLPP_SOCKET_MAGIC, enum mscclppSocketType type = mscclppSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); diff --git a/src/include/utils.h b/src/include/utils.h index 3eff9842..f3318031 100644 --- a/src/include/utils.h +++ b/src/include/utils.h @@ -7,14 +7,12 @@ #ifndef MSCCLPP_UTILS_H_ #define MSCCLPP_UTILS_H_ -#include "alloc.h" -#include "checks.h" -#include "mscclpp.h" -#include -#include -#include #include -#include + +#include + +#include "alloc.h" +// #include "mscclpp.h" // int mscclppCudaCompCap(); @@ -31,8 +29,7 @@ uint64_t getHostHash(); uint64_t getPidHash(); mscclppResult_t getRandomData(void* buffer, size_t bytes); -struct netIf -{ +struct netIf { char prefix[64]; int port; }; @@ -40,11 +37,9 @@ struct netIf int parseStringList(const char* string, struct netIf* ifList, int maxList); bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact); -static long log2i(long n) -{ +static long log2i(long n) { long l = 0; - while (n >>= 1) - l++; + while (n >>= 1) l++; return l; } @@ -54,16 +49,13 @@ int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end); /* get any bytes of random data from /dev/urandom, return 0 if it succeeds; else * return -1 */ -inline mscclppResult_t getRandomData(void* buffer, size_t bytes) -{ +inline mscclppResult_t getRandomData(void* buffer, size_t bytes) { mscclppResult_t ret = mscclppSuccess; if (bytes > 0) { const size_t one = 1UL; FILE* fp = fopen("/dev/urandom", "r"); - if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) - ret = mscclppSystemError; - if (fp) - fclose(fp); + if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) ret = mscclppSystemError; + if (fp) fclose(fp); } return ret; } diff --git a/src/include/utils.hpp b/src/include/utils.hpp new file mode 100644 index 00000000..536d1d29 --- /dev/null +++ b/src/include/utils.hpp @@ -0,0 +1,40 @@ +#ifndef MSCCLPP_UTILS_HPP_ +#define MSCCLPP_UTILS_HPP_ + +#include + +#include + +namespace mscclpp { + +struct Timer { + std::chrono::steady_clock::time_point start; + + Timer() { start = std::chrono::steady_clock::now(); } + + int64_t elapsed() { + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end - start).count(); + } + + void reset() { start = std::chrono::steady_clock::now(); } + + void print(const char* name) { + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start).count(); + printf("%s: %ld us\n", name, elapsed); + } +}; + +struct ScopedTimer { + Timer timer; + const char* name; + + ScopedTimer(const char* name) : name(name) {} + + ~ScopedTimer() { timer.print(name); } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_UTILS_HPP_ diff --git a/src/init.cc b/src/init.cc deleted file mode 100644 index 9b3f7f62..00000000 --- a/src/init.cc +++ /dev/null @@ -1,635 +0,0 @@ -#include "bootstrap.h" -#include "config.h" -#include "core.h" -#if defined(MSCCLPP_USE_GDRCOPY) -#include "gdr.h" -#endif -#include "mscclpp.h" -#include -#include -#if defined(ENABLE_NPKIT) -#include "npkit/npkit.h" -#endif - -static uint64_t hashUniqueId(mscclppUniqueId const& id) -{ - char const* bytes = (char const*)&id; - uint64_t h = 0xdeadbeef; - for (int i = 0; i < (int)sizeof(mscclppUniqueId); i++) { - h ^= h >> 32; - h *= 0x8db3db47fa2994ad; - h += bytes[i]; - } - return h; -} - -pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER; -static bool initialized = false; -// static size_t maxLocalSizeBytes = 0; - -#if defined(MSCCLPP_USE_GDRCOPY) - -gdr_t mscclppGdrCopy = NULL; - -mscclppResult_t initGdrCopy() -{ - if (mscclppGdrCopy == NULL) { - mscclppGdrCopy = mscclppGdrInit(); - if (mscclppGdrCopy == NULL) { - WARN("GDR init failed"); - return mscclppSystemError; - } - } - return mscclppSuccess; -} - -#endif - -static mscclppResult_t mscclppInit() -{ - if (__atomic_load_n(&initialized, __ATOMIC_ACQUIRE)) - return mscclppSuccess; - pthread_mutex_lock(&initLock); - if (!initialized) { - // Always initialize bootstrap network - MSCCLPPCHECK(bootstrapNetInit()); - - __atomic_store_n(&initialized, true, __ATOMIC_RELEASE); - } - pthread_mutex_unlock(&initLock); - return mscclppSuccess; -} - -static std::string mscclppShmFileName(mscclppComm_t comm, int rank) -{ - std::stringstream ss; - ss << "mscclpp." << std::hex << comm->magic << "." << rank; - return ss.str(); -} - -MSCCLPP_API(mscclppResult_t, mscclppGetUniqueId, mscclppUniqueId* out); -mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) -{ - MSCCLPPCHECK(mscclppInit()); - // mscclppCHECK(PtrCheck(out, "GetUniqueId", "out")); - mscclppResult_t res = bootstrapGetUniqueId((struct mscclppBootstrapHandle*)out); - TRACE_CALL("mscclppGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out)); - return res; -} - -MSCCLPP_API(mscclppResult_t, mscclppBootstrapAllGather, mscclppComm_t comm, void* data, int size); -mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size) -{ - MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size)); - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppCommInitRank, mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank); -mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, const char* ipPortPair, int rank) -{ -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK(initGdrCopy()); -#endif - - mscclppResult_t res = mscclppSuccess; - mscclppComm_t _comm = NULL; - // uint64_t hash = getHostHash(); - // uint64_t *hashes; - // std::map hashToNode; - - MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail); - _comm->rank = rank; - _comm->nRanks = nranks; - _comm->devNumaNode = -1; - // We assume that the user has set the device to the intended one already - CUDACHECK(cudaGetDevice(&_comm->cudaDev)); - - MSCCLPPCHECK(bootstrapNetInit(ipPortPair)); - mscclppBootstrapHandle handle; - MSCCLPPCHECK(bootstrapGetUniqueId(&handle, rank == 0, ipPortPair)); - _comm->magic = handle.magic; - - MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t**)&_comm->abortFlag, 1), res, fail); - MSCCLPPCHECK(bootstrapInit(&handle, _comm)); - -#if defined(ENABLE_NPKIT) - // Init NPKit - MSCCLPPCHECK(NpKit::Init(_comm->rank)); -#endif - - *comm = _comm; - return res; -fail: - if (_comm) { - if (_comm->abortFlag) - mscclppCudaHostFree((void*)_comm->abortFlag); - free(_comm); - } - if (comm) - *comm = NULL; - return res; -} - -MSCCLPP_API(mscclppResult_t, mscclppCommInitRankFromId, mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank); -mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, mscclppUniqueId id, int rank) -{ -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK(initGdrCopy()); -#endif - - mscclppResult_t res = mscclppSuccess; - mscclppComm_t _comm = NULL; - mscclppBootstrapHandle* handle = (mscclppBootstrapHandle*)&id; - - MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail); - _comm->rank = rank; - _comm->nRanks = nranks; - // We assume that the user has set the device to the intended one already - CUDACHECK(cudaGetDevice(&_comm->cudaDev)); - - MSCCLPPCHECK(bootstrapNetInit()); - _comm->magic = handle->magic; - - MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t**)&_comm->abortFlag, 1), res, fail); - MSCCLPPCHECK(bootstrapInit(handle, _comm)); - -#if defined(ENABLE_NPKIT) - // Init NPKit - MSCCLPPCHECK(NpKit::Init(_comm->rank)); -#endif - - *comm = _comm; - return res; -fail: - if (_comm) { - if (_comm->abortFlag) - mscclppCudaHostFree((void*)_comm->abortFlag); - free(_comm); - } - if (comm) - *comm = NULL; - return res; -} - -MSCCLPP_API(mscclppResult_t, mscclppCommDestroy, mscclppComm_t comm); -mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) -{ -#if defined(ENABLE_NPKIT) - const char* npkitDumpDir = nullptr; -#endif - - if (comm == NULL) - return mscclppSuccess; - - for (int i = 0; i < comm->nConns; ++i) { - struct mscclppConn* conn = &comm->conns[i]; - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->proxyEpochId)); - } - - for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { - struct mscclppProxyState* proxyState = comm->proxyState[i]; - if (proxyState) { -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK(mscclppGdrCudaFree(proxyState->triggerFifoDesc)); -#else - MSCCLPPCHECK(mscclppCudaHostFree(proxyState->triggerFifo)); -#endif - MSCCLPPCHECK(mscclppCudaFree(proxyState->fifoHead)); -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK(mscclppGdrCudaFree(proxyState->fifoTailDesc)); -#else - MSCCLPPCHECK(mscclppCudaFree(proxyState->fifoTailDev)); -#endif - if (proxyState->p2pStream) - CUDACHECK(cudaStreamDestroy(proxyState->p2pStream)); - CUDACHECK(cudaStreamDestroy(proxyState->fifoStream)); - free(proxyState); - } - } - - for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { - if (comm->ibContext[i]) { - MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i])); - } - } - - for (int i = 0; i < comm->nConns; i++) { - struct mscclppConn* conn = &comm->conns[i]; - if (conn) { - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->sendEpochId)); - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->recvEpochId)); - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->directRecvEpochId)); - } - } - - if (comm->bootstrap) - MSCCLPPCHECK(bootstrapClose(comm->bootstrap)); - - mscclppCudaHostFree((void*)comm->abortFlag); - free(comm); - -#if defined(ENABLE_NPKIT) - // Dump NPKit events and shutdown - npkitDumpDir = getenv("NPKIT_DUMP_DIR"); - if (npkitDumpDir == nullptr) { - WARN("NPKIT_DUMP_DIR is empty"); - } else { - MSCCLPPCHECK(NpKit::Dump(npkitDumpDir)); - } - MSCCLPPCHECK(NpKit::Shutdown()); -#endif - - return mscclppSuccess; -} - -MSCCLPP_API(const char*, mscclppGetErrorString, mscclppResult_t code); -const char* mscclppGetErrorString(mscclppResult_t code) -{ - switch (code) { - case mscclppSuccess: - return "no error"; - case mscclppUnhandledCudaError: - return "unhandled cuda error"; - case mscclppSystemError: - return "unhandled system error"; - case mscclppInternalError: - return "internal error"; - case mscclppInvalidArgument: - return "invalid argument"; - case mscclppInvalidUsage: - return "invalid usage"; - case mscclppRemoteError: - return "remote process exited or there was a network error"; - case mscclppInProgress: - return "MSCCL++ operation in progress"; - default: - return "unknown result code"; - } -} - -MSCCLPP_API(mscclppResult_t, mscclppGetDeviceConnection, mscclppComm_t comm, int remoteRank, int tag, - mscclppDevConn_t** devConn); -mscclppResult_t mscclppGetDeviceConnection(mscclppComm_t comm, int remoteRank, int tag, mscclppDevConn_t** devConn) -{ - for (int i = 0; i < comm->nConns; i++) { - if (comm->devConns[i].remoteRank == remoteRank && comm->devConns[i].tag == tag) { - *devConn = &comm->devConns[i]; - return mscclppSuccess; - } - } - - return mscclppInvalidArgument; -} - -MSCCLPP_API(mscclppResult_t, mscclppGetAllDeviceConnections, mscclppComm_t comm, mscclppDevConn_t** devConns, - int* nConns); -mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, mscclppDevConn_t** devConns, int* nConns) -{ - *nConns = comm->nConns; - *devConns = comm->devConns; - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank, int tag, void* localBuff, - uint64_t buffSize, mscclppTransport_t transportType, const char* ibDev); -mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void* localBuff, uint64_t buffSize, - mscclppTransport_t transportType, const char* ibDev) -{ - // save this processes numa binding and set it to the one closest to the device - // so that all the allocation are close to the device - if (comm->devNumaNode == -1) { - // in case this is our first time - MSCCLPPCHECK(getDeviceNumaNode(comm->cudaDev, &comm->devNumaNode)); - INFO(MSCCLPP_INIT, "NUMA node of device %d is set to %d", comm->cudaDev, comm->devNumaNode); - } - // save numa node bitmask to change it back to user's numa node - mscclppNumaState curProcessState; - MSCCLPPCHECK(getNumaState(&curProcessState)); - // change to device's numa node so that the following allocation are close to the device - MSCCLPPCHECK(numaBind(comm->devNumaNode)); - - if (comm->nConns == MAXCONNECTIONS) { - WARN("Too many connections made"); - return mscclppInternalError; - } - struct mscclppConn* conn = &comm->conns[comm->nConns]; - conn->transport = transportType; - conn->buffSize = buffSize; - - conn->ibCtx = NULL; - conn->ibQp = NULL; - int ibDevIdx = -1; - if (transportType == mscclppTransportIB) { - // Check if an IB context exists - int firstNullIdx = -1; - for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { - if (comm->ibContext[i] == NULL) { - if (firstNullIdx == -1) { - firstNullIdx = i; - } - } else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) { - ibDevIdx = i; - break; - } - } - - // If not, create a new one - if (ibDevIdx == -1) { - // Create a new context. - ibDevIdx = firstNullIdx; - if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) { - WARN("Failed to create IB context"); - return mscclppInternalError; - } - } - // Set the ib context for this conn - conn->ibCtx = comm->ibContext[ibDevIdx]; - } else if (transportType == mscclppTransportP2P) { - // No allocation needed for P2P proxy - } else if (transportType == mscclppTransportSHM) { - WARN("Shared memory interconnection is not implemented yet!"); - return mscclppInternalError; - } else { - WARN("Unexpected connection type!"); - return mscclppInvalidUsage; - } - - // Find/create a proxy state for the given connection - struct mscclppProxyState* proxyState = NULL; - // First see if there is a matching context - // If not, find the first empty proxy - int firstEmptyProxyIndex = -1; - for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { - struct mscclppProxyState* curProxy = comm->proxyState[i]; - if (curProxy && (curProxy->transportType == transportType)) { - if ((transportType == mscclppTransportIB && curProxy->ibContext == conn->ibCtx) || - (transportType == mscclppTransportP2P)) { - proxyState = curProxy; - break; // we found the matching context - } - } - if (curProxy == NULL && firstEmptyProxyIndex == -1) { - firstEmptyProxyIndex = i; - } - } - - if (proxyState == NULL && firstEmptyProxyIndex == -1) { - WARN("Too many proxies have been allocated!"); - return mscclppInvalidUsage; - } - - // If we couldn't find a matching context, create one - if (proxyState == NULL) { - MSCCLPPCHECK(mscclppCalloc(&proxyState, 1)); -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK(mscclppGdrCudaCalloc(&proxyState->triggerFifo, &proxyState->triggerFifoDev, MSCCLPP_PROXY_FIFO_SIZE, - &proxyState->triggerFifoDesc)); -#else - MSCCLPPCHECK(mscclppCudaHostCalloc(&proxyState->triggerFifo, MSCCLPP_PROXY_FIFO_SIZE)); -#endif - MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->fifoHead, 1)); -#if defined(MSCCLPP_USE_GDRCOPY) - MSCCLPPCHECK( - mscclppGdrCudaCalloc(&proxyState->fifoTailDevHostPtr, &proxyState->fifoTailDev, 1, &proxyState->fifoTailDesc)); -#else - MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->fifoTailDev, 1)); -#endif - proxyState->fifoTailHost = 0; - - if (transportType == mscclppTransportIB) { - proxyState->ibContext = conn->ibCtx; - proxyState->p2pStream = NULL; - } else if (transportType == mscclppTransportP2P) { - proxyState->ibContext = NULL; - CUDACHECK(cudaStreamCreateWithFlags(&proxyState->p2pStream, cudaStreamNonBlocking)); - } - CUDACHECK(cudaStreamCreateWithFlags(&proxyState->fifoStream, cudaStreamNonBlocking)); - proxyState->numaNodeToBind = comm->devNumaNode; - - // INFO(MSCCLPP_INIT, "NUMA node for device %d is %d", cudaDev, *numaNode); - proxyState->transportType = transportType; - comm->proxyState[firstEmptyProxyIndex] = proxyState; - } - if (proxyState == NULL) { - // Cannot reach - WARN("Proxy allocation failed!"); - return mscclppInternalError; - } - - struct mscclppDevConn* devConn = &comm->devConns[comm->nConns]; - - conn->devConn = devConn; - conn->devConn->localBuff = localBuff; - MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->sendEpochId, 1)); - MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->recvEpochId, 1)); - MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->directRecvEpochId, 1)); - - conn->devConn->remoteRank = remoteRank; - conn->devConn->tag = tag; - conn->devConn->fifo.connId = comm->nConns; -#if defined(MSCCLPP_USE_GDRCOPY) - conn->devConn->fifo.triggerFifo = proxyState->triggerFifoDev; -#else - conn->devConn->fifo.triggerFifo = proxyState->triggerFifo; -#endif - conn->devConn->fifo.triggerFifoHead = proxyState->fifoHead; - conn->devConn->fifo.triggerFifoTail = proxyState->fifoTailDev; - - comm->nConns++; - // change the numa binding back to user's - MSCCLPPCHECK(setNumaState(curProcessState)); - - return mscclppSuccess; -} - -struct connInfo -{ - cudaIpcMemHandle_t handleBuff; - cudaIpcMemHandle_t handleFlag; - cudaIpcMemHandle_t handleProxyFlag; - cudaIpcMemHandle_t handleRemoteEpochId; - mscclppIbQpInfo infoQp; - mscclppIbMrInfo infoBuffMr; - mscclppIbMrInfo infoLocalFlagMr; - mscclppIbMrInfo infoProxyFlagMr; -}; - -mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/) -{ - if (connInfo == NULL || conn == NULL) { - WARN("connInfo or connection cannot be null"); - return mscclppInternalError; - } - struct mscclppDevConn* devConn = conn->devConn; - MSCCLPPCHECK(mscclppCudaCalloc(&devConn->proxyEpochId, 1)); - CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleProxyFlag, devConn->proxyEpochId)); - CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleBuff, devConn->localBuff)); - CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleFlag, devConn->sendEpochId)); - CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleRemoteEpochId, devConn->directRecvEpochId)); - return mscclppSuccess; -} - -mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*output*/) -{ - if (connInfo == NULL || conn == NULL) { - WARN("ipcHandles or connection cannot be null"); - return mscclppInternalError; - } - CUDACHECK( - cudaIpcOpenMemHandle((void**)&conn->devConn->remoteBuff, connInfo->handleBuff, cudaIpcMemLazyEnablePeerAccess)); - CUDACHECK( - cudaIpcOpenMemHandle((void**)&conn->devConn->remoteFlag, connInfo->handleFlag, cudaIpcMemLazyEnablePeerAccess)); - CUDACHECK(cudaIpcOpenMemHandle((void**)&conn->devConn->remoteEpochId, connInfo->handleRemoteEpochId, - cudaIpcMemLazyEnablePeerAccess)); - CUDACHECK( - cudaIpcOpenMemHandle((void**)&conn->remoteProxyFlag, connInfo->handleProxyFlag, cudaIpcMemLazyEnablePeerAccess)); - return mscclppSuccess; -} - -mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/) -{ - if (connInfo == NULL || conn == NULL) { - WARN("connInfo or connection cannot be null"); - return mscclppInternalError; - } - struct mscclppDevConn* devConn = conn->devConn; - devConn->remoteBuff = NULL; - devConn->remoteFlag = NULL; - MSCCLPPCHECK(mscclppCudaCalloc(&devConn->proxyEpochId, 1)); - - struct mscclppIbContext* ibCtx = conn->ibCtx; - if (conn->ibQp == NULL) { - MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp)); - } - // TODO(chhwang): can we register only one MR for the following three? - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->sendEpochId, sizeof(uint64_t), &conn->ibLocalFlagMr)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->proxyEpochId, sizeof(uint64_t), &conn->ibProxyFlagMr)); - connInfo->infoQp = conn->ibQp->info; - connInfo->infoBuffMr = conn->ibBuffMr->info; - connInfo->infoLocalFlagMr = conn->ibLocalFlagMr->info; - connInfo->infoProxyFlagMr = conn->ibProxyFlagMr->info; - return mscclppSuccess; -} - -mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, struct mscclppConn* conn /*output*/) -{ - if (connInfo == NULL || conn == NULL) { - WARN("ipcHandles or connection cannot be null"); - return mscclppInternalError; - } - if (conn->ibQp->rtr(&connInfo->infoQp) != 0) { - WARN("Failed to transition QP to RTR"); - return mscclppInvalidUsage; - } - if (conn->ibQp->rts() != 0) { - WARN("Failed to transition QP to RTS"); - return mscclppInvalidUsage; - } - conn->ibBuffMrInfo = connInfo->infoBuffMr; - conn->ibLocalFlagMrInfo = connInfo->infoLocalFlagMr; - conn->ibProxyFlagMrInfo = connInfo->infoProxyFlagMr; - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm); -mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) -{ - // Send info to peers - for (int i = 0; i < comm->nConns; ++i) { - struct mscclppConn* conn = &comm->conns[i]; - - struct connInfo cInfo; - if (conn->transport == mscclppTransportP2P) { - MSCCLPPCHECK(mscclppP2pConnectionSetupStart(&cInfo, conn)); - } else if (conn->transport == mscclppTransportIB) { - MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn)); - } - // TODO: from saemal: do we possibly deadlock if there are too many outstanding sends? - MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo))); - } - - // Recv info from peers - for (int i = 0; i < comm->nConns; ++i) { - struct mscclppConn* conn = &comm->conns[i]; - struct connInfo cInfo; - MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo))); - if (conn->transport == mscclppTransportP2P) { - MSCCLPPCHECK(mscclppP2pConnectionSetupEnd(&cInfo, conn)); - } else if (conn->transport == mscclppTransportIB) { - MSCCLPPCHECK(mscclppIbConnectionSetupEnd(&cInfo, conn)); - } - } - - // a barrier to ensure setup on all gpus are done and we can return to the user - MSCCLPPCHECK(mscclppBootstrapBarrier(comm)); - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppProxyLaunch, mscclppComm_t comm); -mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm) -{ - MSCCLPPCHECK(mscclppProxyCreate(comm)); - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppBootstrapBarrier, mscclppComm_t comm); -mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm) -{ - int* tmp = new int[comm->nRanks]; - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - delete[] tmp; - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppProxyStop, mscclppComm_t comm); -mscclppResult_t mscclppProxyStop(mscclppComm_t comm) -{ - // a barrier to make sure all ranks are done with their work before stopping the proxy - MSCCLPPCHECK(mscclppBootstrapBarrier(comm)); - - MSCCLPPCHECK(mscclppProxyDestroy(comm)); - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppCommRank, mscclppComm_t comm, int* rank); -mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank) -{ - if (comm == NULL || rank == NULL) { - WARN("comm or rank cannot be null"); - return mscclppInvalidUsage; - } - *rank = comm->rank; - return mscclppSuccess; -} - -MSCCLPP_API(mscclppResult_t, mscclppCommSize, mscclppComm_t comm, int* size); -mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size) -{ - if (comm == NULL || size == NULL) { - WARN("comm or size cannot be null"); - return mscclppInvalidUsage; - } - *size = comm->nRanks; - return mscclppSuccess; -} - -MSCCLPP_API(void, mscclppDefaultLogHandler, const char* msg); -void mscclppDefaultLogHandler(const char* msg) -{ - mscclppDebugDefaultLogHandler(msg); -} - -MSCCLPP_API(mscclppResult_t, mscclppSetLogHandler, mscclppLogHandler_t handler); -mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler) -{ - return mscclppDebugSetLogHandler(handler); -} - -MSCCLPP_API(mscclppResult_t, mscclppSetBootstrapConnTimeout, int timeout); -mscclppResult_t mscclppSetBootstrapConnTimeout(int timeout) -{ - mscclppConfig* config = mscclppConfig::getInstance(); - config->setBootstrapConnectionTimeoutConfig(timeout); - return mscclppSuccess; -} \ No newline at end of file diff --git a/src/misc/npkit.cc b/src/npkit/npkit.cc similarity index 93% rename from src/misc/npkit.cc rename to src/npkit/npkit.cc index 4a7eb849..49ee7a12 100644 --- a/src/misc/npkit.cc +++ b/src/npkit/npkit.cc @@ -1,9 +1,12 @@ -#include -#include +#include "npkit.h" + +#include #include +#include +#include + #include "alloc.h" -#include "npkit/npkit.h" uint64_t NpKit::rank_ = 0; @@ -15,8 +18,7 @@ NpKitEventCollectContext* NpKit::cpu_collect_contexts_ = nullptr; uint64_t NpKit::cpu_base_system_timestamp_ = 0; uint64_t NpKit::cpu_base_steady_timestamp_ = 0; -mscclppResult_t NpKit::Init(int rank) -{ +mscclppResult_t NpKit::Init(int rank) { uint64_t i = 0; NpKitEventCollectContext ctx; ctx.event_buffer_head = 0; @@ -46,8 +48,7 @@ mscclppResult_t NpKit::Init(int rank) return mscclppSuccess; } -mscclppResult_t NpKit::Dump(const std::string& dump_dir) -{ +mscclppResult_t NpKit::Dump(const std::string& dump_dir) { uint64_t i = 0; std::string dump_file_path; @@ -112,8 +113,7 @@ mscclppResult_t NpKit::Dump(const std::string& dump_dir) return mscclppSuccess; } -mscclppResult_t NpKit::Shutdown() -{ +mscclppResult_t NpKit::Shutdown() { uint64_t i = 0; // Free CPU event data structures @@ -133,13 +133,9 @@ mscclppResult_t NpKit::Shutdown() return mscclppSuccess; } -NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts() -{ - return gpu_collect_contexts_; -} +NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts() { return gpu_collect_contexts_; } -void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id) -{ +void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id) { uint64_t event_buffer_head = cpu_collect_contexts_[channel_id].event_buffer_head; if (event_buffer_head < kMaxNumCpuEventsPerBuffer) { NpKitEvent& event = cpu_collect_contexts_[channel_id].event_buffer[event_buffer_head]; @@ -151,8 +147,7 @@ void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t } } -uint64_t NpKit::GetCpuTimestamp() -{ +uint64_t NpKit::GetCpuTimestamp() { uint64_t cpu_curr_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count(); return cpu_base_steady_timestamp_ + (cpu_curr_steady_timestamp_ - cpu_base_steady_timestamp_); } diff --git a/src/include/npkit/npkit.h b/src/npkit/npkit.h similarity index 90% rename from src/include/npkit/npkit.h rename to src/npkit/npkit.h index a0691afd..c15bb812 100644 --- a/src/include/npkit/npkit.h +++ b/src/npkit/npkit.h @@ -2,16 +2,13 @@ #define NPKIT_H_ #include -#include -#include +#include "mscclpp.h" +#include "npkit_event.h" +#include "npkit_struct.h" -#include "npkit/npkit_event.h" -#include "npkit/npkit_struct.h" - -class NpKit -{ -public: +class NpKit { + public: static const uint64_t kNumGpuEventBuffers = 512; static const uint64_t kNumCpuEventBuffers = 32; @@ -24,9 +21,9 @@ public: static NpKitEventCollectContext* GetGpuEventCollectContexts(); +#ifdef __CUDACC__ static inline __device__ void CollectGpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, - NpKitEventCollectContext* ctx) - { + NpKitEventCollectContext* ctx) { uint64_t event_buffer_head = ctx->event_buffer_head; if (event_buffer_head < kMaxNumGpuEventsPerBuffer) { NpKitEvent& event = ctx->event_buffer[event_buffer_head]; @@ -37,12 +34,13 @@ public: ctx->event_buffer_head++; } } +#endif // __CUDACC__ static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id); static uint64_t GetCpuTimestamp(); -private: + private: // 64K * 512 * 16B = 512MB per GPU static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16; diff --git a/src/include/npkit/npkit_event.h b/src/npkit/npkit_event.h similarity index 100% rename from src/include/npkit/npkit_event.h rename to src/npkit/npkit_event.h diff --git a/src/include/npkit/npkit_struct.h b/src/npkit/npkit_struct.h similarity index 87% rename from src/include/npkit/npkit_struct.h rename to src/npkit/npkit_struct.h index 2fc19821..a18e8798 100644 --- a/src/include/npkit/npkit_struct.h +++ b/src/npkit/npkit_struct.h @@ -7,8 +7,7 @@ union NpKitEvent { uint64_t bits[2]; - struct - { + struct { uint64_t type : 8; uint64_t size : 32; uint64_t rsvd : 24; @@ -16,8 +15,7 @@ union NpKitEvent { } fields; }; -struct NpKitEventCollectContext -{ +struct NpKitEventCollectContext { NpKitEvent* event_buffer; uint64_t event_buffer_head; }; diff --git a/src/param.cc b/src/param.cc deleted file mode 100644 index 2af48084..00000000 --- a/src/param.cc +++ /dev/null @@ -1,90 +0,0 @@ -/************************************************************************* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "param.h" -#include "debug.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -const char* userHomeDir() -{ - struct passwd* pwUser = getpwuid(getuid()); - return pwUser == NULL ? NULL : pwUser->pw_dir; -} - -void setEnvFile(const char* fileName) -{ - FILE* file = fopen(fileName, "r"); - if (file == NULL) - return; - - char* line = NULL; - char envVar[1024]; - char envValue[1024]; - size_t n = 0; - ssize_t read; - while ((read = getline(&line, &n, file)) != -1) { - if (line[read - 1] == '\n') - line[read - 1] = '\0'; - int s = 0; // Env Var Size - while (line[s] != '\0' && line[s] != '=') - s++; - if (line[s] == '\0') - continue; - strncpy(envVar, line, std::min(1023, s)); - envVar[s] = '\0'; - s++; - strncpy(envValue, line + s, 1023); - envValue[1023] = '\0'; - setenv(envVar, envValue, 0); - // printf("%s : %s->%s\n", fileName, envVar, envValue); - } - if (line) - free(line); - fclose(file); -} - -void initEnv() -{ - char confFilePath[1024]; - const char* userDir = userHomeDir(); - if (userDir) { - sprintf(confFilePath, "%s/.mscclpp.conf", userDir); - setEnvFile(confFilePath); - } - sprintf(confFilePath, "/etc/mscclpp.conf"); - setEnvFile(confFilePath); -} - -void mscclppLoadParam(char const* env, int64_t deftVal, int64_t uninitialized, int64_t* cache) -{ - static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - pthread_mutex_lock(&mutex); - if (__atomic_load_n(cache, __ATOMIC_RELAXED) == uninitialized) { - char* str = getenv(env); - int64_t value = deftVal; - if (str && strlen(str) > 0) { - errno = 0; - value = strtoll(str, nullptr, 0); - if (errno) { - value = deftVal; - INFO(MSCCLPP_ALL, "Invalid value %s for %s, using default %lld.", str, env, (long long)deftVal); - } else { - INFO(MSCCLPP_ALL, "%s set by environment to %lld.", env, (long long)value); - } - } - __atomic_store_n(cache, value, __ATOMIC_RELAXED); - } - pthread_mutex_unlock(&mutex); -} diff --git a/src/proxy.cc b/src/proxy.cc index c126d639..8a066279 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -1,284 +1,101 @@ -#include "alloc.h" -#include "checks.h" -#include "comm.h" -#include "debug.h" -#include "ib.h" -#include "socket.h" - -#include -#include -#include +#include +#include +#include #include -#include "npkit/npkit.h" +#include "api.h" +#include "utils.h" +#include "utils.hpp" -#define MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD 100 +namespace mscclpp { -#define PROXYCUDACHECK(cmd) \ - do { \ - cudaError_t err = cmd; \ - if (err != cudaSuccess) { \ - WARN("CUDA error from proxy: %s", cudaGetErrorString(err)); \ - return NULL; \ - } \ - } while (false) +const int ProxyStopCheckPeriod = 1000; -#define PROXYMSCCLPPCHECK(call) \ - do { \ - mscclppResult_t res = call; \ - if (res != mscclppSuccess && res != mscclppInProgress) { \ - /* Print the back trace*/ \ - if (mscclppDebugNoWarn == 0) \ - INFO(MSCCLPP_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \ - return NULL; \ - } \ - } while (0); +const int ProxyFlushPeriod = 4; -struct proxyArgs -{ - struct mscclppComm* comm; - struct mscclppProxyState* proxyState; +struct Proxy::Impl { + ProxyHandler handler; + std::function threadInit; + HostProxyFifo fifo; + std::thread service; + std::atomic_bool running; + + Impl(ProxyHandler handler, std::function threadInit) + : handler(handler), threadInit(threadInit), running(false) {} }; -static void readTrigger(mscclppTrigger* dst, mscclppTrigger* src) -{ - __m128i xmm0 = _mm_load_si128((__m128i*)src); - _mm_store_si128((__m128i*)dst, xmm0); +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function threadInit) { + pimpl = std::make_unique(handler, threadInit); } -#if defined(ENABLE_NPKIT) +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) {} -static void npkitInitReqIds(struct mscclppComm* comm) -{ - for (int i = 0; i < comm->nConns; i++) { - struct mscclppConn* conn = &comm->conns[i]; - conn->npkitUsedReqIds.resize(0); - conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS); - for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) { - conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1; - } +MSCCLPP_API_CPP Proxy::~Proxy() { + if (pimpl) { + stop(); } } -static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size, int channelId) -{ - uint64_t reqId = 0; - if (conn->npkitFreeReqIds.size() == 0) { - reqId = conn->npkitUsedReqIds.size(); - } else { - reqId = conn->npkitFreeReqIds.back(); - conn->npkitFreeReqIds.pop_back(); - } - conn->npkitUsedReqIds.push_back(reqId); - NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId); -} +MSCCLPP_API_CPP void Proxy::start() { + pimpl->running = true; + pimpl->service = std::thread([this] { + pimpl->threadInit(); -static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int channelId) -{ - while (conn->npkitUsedReqIds.size()) { - uint64_t reqId = conn->npkitUsedReqIds.back(); - NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId); - conn->npkitFreeReqIds.push_back(reqId); - conn->npkitUsedReqIds.pop_back(); - } -} + ProxyHandler handler = this->pimpl->handler; + HostProxyFifo& fifo = this->pimpl->fifo; + std::atomic_bool& running = this->pimpl->running; + ProxyTrigger trigger; -#else + int runCnt = ProxyStopCheckPeriod; + uint64_t flushCnt = 0; + for (;;) { + if (runCnt-- == 0) { + runCnt = ProxyStopCheckPeriod; + if (!running) { + break; + } + } + // Poll to see if we are ready to send anything + fifo.poll(&trigger); + if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers + continue; // there is one in progress + } -#define npkitInitReqIds(comm) + ProxyHandlerResult result = handler(trigger); -#define npkitCollectEntryEvent(conn, type, size, channelId) + // Send completion: reset only the high 64 bits + fifo.pop(); + // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure + // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush + // request. + if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) { + // TODO: relocate this check: || (trigger.fields.type & mscclppSync) + fifo.flushTail(); + } -#define npkitCollectExitEvents(conn, type, channelId) - -#endif - -void* mscclppProxyService(void* _args) -{ - struct proxyArgs* args = (struct proxyArgs*)_args; - struct mscclppComm* comm = args->comm; - - // from this point on, proxy thread will stay close to the device - PROXYCUDACHECK(cudaSetDevice(comm->cudaDev)); - PROXYMSCCLPPCHECK(numaBind(comm->devNumaNode)); - - volatile mscclppProxyRunState_t* run = &args->proxyState->run; - mscclppTrigger* fifo = args->proxyState->triggerFifo; - uint64_t* fifoTail = &args->proxyState->fifoTailHost; -#if defined(MSCCLPP_USE_GDRCOPY) - volatile uint64_t* fifoTailDevPtr = args->proxyState->fifoTailDevHostPtr; -#else - uint64_t* fifoTailDevPtr = args->proxyState->fifoTailDev; -#endif - uint64_t fifoTailCached = *fifoTail; - mscclppTrigger trigger; - mscclppIbContext* ibCtx = args->proxyState->ibContext; - cudaStream_t p2pStream = args->proxyState->p2pStream; -#if !defined(MSCCLPP_USE_GDRCOPY) - cudaStream_t fifoStream = args->proxyState->fifoStream; -#endif - bool isP2pProxy = (ibCtx == nullptr); - free(_args); // allocated in mscclppProxyCreate - - npkitInitReqIds(comm); - - int counter = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD; - for (;;) { - if (counter-- == 0) { - counter = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD; - if (*run != MSCCLPP_PROXY_RUN_STATE_RUNNING) { + if (result == ProxyHandlerResult::Stop) { break; } } - // Poll to see if we are ready to send anything - readTrigger(&trigger, &fifo[fifoTailCached % MSCCLPP_PROXY_FIFO_SIZE]); - if (trigger.value[0] == 0) { - continue; // there is one in progreess - } - struct mscclppConn* conn = &comm->conns[trigger.fields.connId]; - int ret = 0; - // Iterate over what send is needed - if (trigger.fields.type & mscclppData) { - if (isP2pProxy) { - void* srcBuff = (void*)((char*)conn->devConn->localBuff + trigger.fields.srcDataOffset); - void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + trigger.fields.dstDataOffset); - PROXYCUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, trigger.fields.dataSize, cudaMemcpyDeviceToDevice, p2pStream)); - npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)trigger.fields.dataSize, - trigger.fields.connId); - } else { - conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, - /*wrId=*/0, /*srcOffset=*/trigger.fields.srcDataOffset, - /*dstOffset=*/trigger.fields.dstDataOffset, - /*signaled=*/false); - if ((ret = conn->ibQp->postSend()) != 0) { - // Return value is errno. - WARN("data postSend failed: errno %d", ret); - } - npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)trigger.fields.dataSize, - trigger.fields.connId); - } - } - if (trigger.fields.type & mscclppFlag) { - if (isP2pProxy) { - PROXYCUDACHECK(cudaMemcpyAsync(conn->remoteProxyFlag, conn->devConn->sendEpochId, sizeof(uint64_t), - cudaMemcpyDeviceToDevice, p2pStream)); - npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t), - trigger.fields.connId); - } else { - // My local flag is copied to the peer's proxy flag - conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibProxyFlagMrInfo, sizeof(uint64_t), - /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/0, /*signaled=*/true); - if ((ret = conn->ibQp->postSend()) != 0) { - WARN("flag postSend failed: errno %d", ret); - } - npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t), trigger.fields.connId); - } - } - // Wait for completion - if (trigger.fields.type & mscclppSync) { - if (isP2pProxy) { - PROXYCUDACHECK(cudaStreamSynchronize(p2pStream)); - npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT, trigger.fields.connId); - } else { - int rank = comm->rank; - bool isWaiting = true; - while (isWaiting) { - int wcNum = conn->ibQp->pollCq(); - if (wcNum < 0) { - WARN("rank %d pollCq failed: errno %d", rank, errno); - continue; - } - for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &conn->ibQp->wcs[i]; - if (wc->status != IBV_WC_SUCCESS) { - WARN("rank %d wc status %d", rank, wc->status); - continue; - } - if (wc->qp_num != conn->ibQp->qp->qp_num) { - WARN("rank %d got wc of unknown qp_num %d", rank, wc->qp_num); - continue; - } - if (wc->opcode == IBV_WC_RDMA_WRITE) { - isWaiting = false; - break; - } - } - } - npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT, trigger.fields.connId); - } - } - - // Send completion: reset only the high 64 bits - *(volatile uint64_t*)(&fifo[fifoTailCached % MSCCLPP_PROXY_FIFO_SIZE]) = 0; - fifoTailCached++; - // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure - // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush - // request. - if (((fifoTailCached % MSCCLPP_PROXY_FIFO_FLUSH_COUNTER) == 0) || (trigger.fields.type & mscclppSync)) { -#if defined(MSCCLPP_USE_GDRCOPY) - *fifoTailDevPtr = fifoTailCached; -#else - PROXYCUDACHECK( - cudaMemcpyAsync(fifoTailDevPtr, &fifoTailCached, sizeof(uint64_t), cudaMemcpyHostToDevice, fifoStream)); -#endif - } - } - *fifoTail = fifoTailCached; - - // make sure the tail is flushed before we shut the proxy -#if defined(MSCCLPP_USE_GDRCOPY) - *fifoTailDevPtr = fifoTailCached; -#else - PROXYCUDACHECK( - cudaMemcpyAsync(fifoTailDevPtr, &fifoTailCached, sizeof(uint64_t), cudaMemcpyHostToDevice, fifoStream)); - PROXYCUDACHECK(cudaStreamSynchronize(fifoStream)); -#endif - if (isP2pProxy) { - PROXYCUDACHECK(cudaStreamSynchronize(p2pStream)); - } - *run = MSCCLPP_PROXY_RUN_STATE_IDLE; - return NULL; + // make sure the tail is flushed before we shut the proxy + fifo.flushTail(/*sync=*/true); + // TODO: do these need to run? + // bool isP2pProxy = (proxyState->ibContext == nullptr); + // if (isP2pProxy) { + // cudaStream_t p2pStream = proxyState->p2pStream; + // PROXYCUDACHECK(cudaStreamSynchronize(p2pStream)); + // } + }); } -mscclppResult_t mscclppProxyCreate(struct mscclppComm* comm) -{ - for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { - struct mscclppProxyState* proxyState = comm->proxyState[i]; - if (proxyState == NULL) - break; - - struct proxyArgs* args; - MSCCLPPCHECK(mscclppCalloc(&args, 1)); - args->comm = comm; - args->proxyState = proxyState; - - proxyState->run = MSCCLPP_PROXY_RUN_STATE_RUNNING; - pthread_create(&proxyState->thread, NULL, mscclppProxyService, args); - if (proxyState->transportType == mscclppTransportP2P) { - mscclppSetThreadName(proxyState->thread, "MSCCLPP Service P2P - %02d", comm->cudaDev); - } else if (proxyState->transportType == mscclppTransportIB) { - mscclppSetThreadName(proxyState->thread, "MSCCLPP Service IB - %02d", i); - } +MSCCLPP_API_CPP void Proxy::stop() { + pimpl->running = false; + if (pimpl->service.joinable()) { + pimpl->service.join(); } - return mscclppSuccess; } -mscclppResult_t mscclppProxyDestroy(struct mscclppComm* comm) -{ - for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { - struct mscclppProxyState* proxyState = comm->proxyState[i]; - if (proxyState == NULL) - break; +MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { return pimpl->fifo; } - volatile int* run = (volatile int*)&proxyState->run; - if (*run == MSCCLPP_PROXY_RUN_STATE_IDLE) { - continue; - } - *run = MSCCLPP_PROXY_RUN_STATE_EXITING; - while (*run == MSCCLPP_PROXY_RUN_STATE_EXITING && *comm->abortFlag == 0) { - usleep(1000); - } - } - return mscclppSuccess; -} +} // namespace mscclpp diff --git a/src/registered_memory.cc b/src/registered_memory.cc new file mode 100644 index 00000000..4781ba61 --- /dev/null +++ b/src/registered_memory.cc @@ -0,0 +1,143 @@ +#include "registered_memory.hpp" + +#include + +#include + +#include "api.h" +#include "checks.hpp" +#include "utils.h" + +namespace mscclpp { + +RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) + : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) { + if (transports.has(Transport::CudaIpc)) { + TransportInfo transportInfo; + transportInfo.transport = Transport::CudaIpc; + cudaIpcMemHandle_t handle; + + void* baseDataPtr; + size_t baseDataSize; // dummy + CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); + CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); + // TODO: bug with offset of base? + transportInfo.cudaIpcBaseHandle = handle; + transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr; + this->transportInfos.push_back(transportInfo); + } + if ((transports & AllIBTransports).any()) { + auto addIb = [&](Transport ibTransport) { + TransportInfo transportInfo; + transportInfo.transport = ibTransport; + const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); + transportInfo.ibMr = mr; + transportInfo.ibLocal = true; + transportInfo.ibMrInfo = mr->getInfo(); + this->transportInfos.push_back(transportInfo); + INFO(MSCCLPP_NET, "IB mr for address %p with size %ld is registered", data, size); + }; + if (transports.has(Transport::IB0)) addIb(Transport::IB0); + if (transports.has(Transport::IB1)) addIb(Transport::IB1); + if (transports.has(Transport::IB2)) addIb(Transport::IB2); + if (transports.has(Transport::IB3)) addIb(Transport::IB3); + if (transports.has(Transport::IB4)) addIb(Transport::IB4); + if (transports.has(Transport::IB5)) addIb(Transport::IB5); + if (transports.has(Transport::IB6)) addIb(Transport::IB6); + if (transports.has(Transport::IB7)) addIb(Transport::IB7); + } +} + +MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) {} + +MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; + +MSCCLPP_API_CPP void* RegisteredMemory::data() { return pimpl->data; } + +MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; } + +MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; } + +MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; } + +MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { + std::vector result; + std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); + if (pimpl->transportInfos.size() > std::numeric_limits::max()) { + throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError); + } + int8_t transportCount = pimpl->transportInfos.size(); + std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); + for (auto& entry : pimpl->transportInfos) { + std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); + if (entry.transport == Transport::CudaIpc) { + std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase), + std::back_inserter(result)); + } else if (AllIBTransports.has(entry.transport)) { + std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); + } else { + throw mscclpp::Error("Unknown transport", ErrorCode::InternalError); + } + } + return result; +} + +MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { + return RegisteredMemory(std::make_shared(data)); +} + +RegisteredMemory::Impl::Impl(const std::vector& serialization) { + auto it = serialization.begin(); + std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); + it += sizeof(this->size); + std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); + it += sizeof(this->rank); + std::copy_n(it, sizeof(this->hostHash), reinterpret_cast(&this->hostHash)); + it += sizeof(this->hostHash); + std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); + it += sizeof(this->transports); + int8_t transportCount; + std::copy_n(it, sizeof(transportCount), reinterpret_cast(&transportCount)); + it += sizeof(transportCount); + for (int i = 0; i < transportCount; ++i) { + TransportInfo transportInfo; + std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); + it += sizeof(transportInfo.transport); + if (transportInfo.transport == Transport::CudaIpc) { + std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), + reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); + it += sizeof(transportInfo.cudaIpcBaseHandle); + std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), + reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); + it += sizeof(transportInfo.cudaIpcOffsetFromBase); + } else if (AllIBTransports.has(transportInfo.transport)) { + std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast(&transportInfo.ibMrInfo)); + it += sizeof(transportInfo.ibMrInfo); + transportInfo.ibLocal = false; + } else { + throw mscclpp::Error("Unknown transport", ErrorCode::InternalError); + } + this->transportInfos.push_back(transportInfo); + } + if (it != serialization.end()) { + throw mscclpp::Error("Serialization failed", ErrorCode::InternalError); + } + + if (transports.has(Transport::CudaIpc)) { + uint64_t localHostHash = getHostHash(); + if (localHostHash == this->hostHash) { + auto entry = getTransportInfo(Transport::CudaIpc); + void* base; + CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); + data = static_cast(base) + entry.cudaIpcOffsetFromBase; + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data); + } + } +} + +} // namespace mscclpp diff --git a/src/utils.cc b/src/utils.cc index c0766765..6e9e1970 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -5,9 +5,11 @@ ************************************************************************/ #include "utils.h" -#include "core.h" +#include #include + +#include #include // Get current Compute Capability @@ -20,20 +22,17 @@ // return ccMajor*10+ccMinor; // } -mscclppResult_t int64ToBusId(int64_t id, char* busId) -{ +mscclppResult_t int64ToBusId(int64_t id, char* busId) { sprintf(busId, "%04lx:%02lx:%02lx.%01lx", (id) >> 20, (id & 0xff000) >> 12, (id & 0xff0) >> 4, (id & 0xf)); return mscclppSuccess; } -mscclppResult_t busIdToInt64(const char* busId, int64_t* id) -{ - char hexStr[17]; // Longest possible int64 hex string + null terminator. +mscclppResult_t busIdToInt64(const char* busId, int64_t* id) { + char hexStr[17]; // Longest possible int64 hex string + null terminator. int hexOffset = 0; for (int i = 0; hexOffset < sizeof(hexStr) - 1; i++) { char c = busId[i]; - if (c == '.' || c == ':') - continue; + if (c == '.' || c == ':') continue; if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f')) { hexStr[hexOffset++] = busId[i]; } else @@ -45,8 +44,7 @@ mscclppResult_t busIdToInt64(const char* busId, int64_t* id) } // Convert a logical cudaDev index to the NVML device minor number -mscclppResult_t getBusId(int cudaDev, std::string* busId) -{ +mscclppResult_t getBusId(int cudaDev, std::string* busId) { // On most systems, the PCI bus ID comes back as in the 0000:00:00.0 // format. Still need to allocate proper space in case PCI domain goes // higher. @@ -60,8 +58,7 @@ mscclppResult_t getBusId(int cudaDev, std::string* busId) return mscclppSuccess; } -mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode) -{ +mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode) { std::string busId; MSCCLPPCHECK(getBusId(cudaDev, &busId)); @@ -80,21 +77,18 @@ mscclppResult_t getDeviceNumaNode(int cudaDev, int* numaNode) return mscclppSuccess; } -mscclppResult_t getHostName(char* hostname, int maxlen, const char delim) -{ +mscclppResult_t getHostName(char* hostname, int maxlen, const char delim) { if (gethostname(hostname, maxlen) != 0) { strncpy(hostname, "unknown", maxlen); return mscclppSystemError; } int i = 0; - while ((hostname[i] != delim) && (hostname[i] != '\0') && (i < maxlen - 1)) - i++; + while ((hostname[i] != delim) && (hostname[i] != '\0') && (i < maxlen - 1)) i++; hostname[i] = '\0'; return mscclppSuccess; } -uint64_t getHash(const char* string, int n) -{ +uint64_t getHash(const char* string, int n) { // Based on DJB2a, result = result * 33 ^ char uint64_t result = 5381; for (int c = 0; c < n; c++) { @@ -112,8 +106,7 @@ uint64_t getHash(const char* string, int n) * This string can be overridden by using the MSCCLPP_HOSTID env var. */ #define HOSTID_FILE "/proc/sys/kernel/random/boot_id" -uint64_t getHostHash(void) -{ +uint64_t computeHostHash(void) { char hostHash[1024]; char* hostId; @@ -144,21 +137,24 @@ uint64_t getHostHash(void) return getHash(hostHash, strlen(hostHash)); } +uint64_t getHostHash(void) { + thread_local std::unique_ptr hostHash = std::make_unique(computeHostHash()); + return *hostHash; +} + /* Generate a hash of the unique identifying string for this process * that will be unique for both bare-metal and container instances * Equivalent of a hash of; * * $$ $(readlink /proc/self/ns/pid) */ -uint64_t getPidHash(void) -{ +uint64_t getPidHash(void) { char pname[1024]; // Start off with our pid ($$) sprintf(pname, "%ld", (long)getpid()); int plen = strlen(pname); int len = readlink("/proc/self/ns/pid", pname + plen, sizeof(pname) - 1 - plen); - if (len < 0) - len = 0; + if (len < 0) len = 0; pname[plen + len] = '\0'; TRACE(MSCCLPP_INIT, "unique PID '%s'", pname); @@ -166,10 +162,8 @@ uint64_t getPidHash(void) return getHash(pname, strlen(pname)); } -int parseStringList(const char* string, struct netIf* ifList, int maxList) -{ - if (!string) - return 0; +int parseStringList(const char* string, struct netIf* ifList, int maxList) { + if (!string) return 0; const char* ptr = string; @@ -185,8 +179,7 @@ int parseStringList(const char* string, struct netIf* ifList, int maxList) ifNum++; ifC = 0; } - while (c != ',' && c != '\0') - c = *(++ptr); + while (c != ',' && c != '\0') c = *(++ptr); } else if (c == ',' || c == '\0') { if (ifC > 0) { ifList[ifNum].prefix[ifC] = '\0'; @@ -203,29 +196,22 @@ int parseStringList(const char* string, struct netIf* ifList, int maxList) return ifNum; } -static bool matchIf(const char* string, const char* ref, bool matchExact) -{ +static bool matchIf(const char* string, const char* ref, bool matchExact) { // Make sure to include '\0' in the exact case int matchLen = matchExact ? strlen(string) + 1 : strlen(ref); return strncmp(string, ref, matchLen) == 0; } -static bool matchPort(const int port1, const int port2) -{ - if (port1 == -1) - return true; - if (port2 == -1) - return true; - if (port1 == port2) - return true; +static bool matchPort(const int port1, const int port2) { + if (port1 == -1) return true; + if (port2 == -1) return true; + if (port1 == port2) return true; return false; } -bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) -{ +bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) { // Make an exception for the case where no user list is defined - if (listSize == 0) - return true; + if (listSize == 0) return true; for (int i = 0; i < listSize; i++) { if (matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) { @@ -235,8 +221,7 @@ bool matchIfList(const char* string, int port, struct netIf* ifList, int listSiz return false; } -mscclppResult_t numaBind(int node) -{ +mscclppResult_t numaBind(int node) { int totalNumNumaNodes = numa_num_configured_nodes(); if (node < 0 || node >= totalNumNumaNodes) { WARN("Invalid NUMA node %d, must be between 0 and %d", node, totalNumNumaNodes); @@ -249,9 +234,7 @@ mscclppResult_t numaBind(int node) return mscclppSuccess; } -mscclppResult_t getNumaState(mscclppNumaState* state) -{ - +mscclppResult_t getNumaState(mscclppNumaState* state) { mscclppNumaState state_ = numa_get_run_node_mask(); if (state_ == NULL) { WARN("Failed to get NUMA node mask of the running process"); @@ -261,8 +244,7 @@ mscclppResult_t getNumaState(mscclppNumaState* state) return mscclppSuccess; } -mscclppResult_t setNumaState(mscclppNumaState state) -{ +mscclppResult_t setNumaState(mscclppNumaState state) { if (state == NULL) { WARN("Invalid NUMA state"); return mscclppInvalidUsage; @@ -271,12 +253,8 @@ mscclppResult_t setNumaState(mscclppNumaState state) return mscclppSuccess; } -mscclppTime_t getClock() -{ - return std::chrono::steady_clock::now(); -} +mscclppTime_t getClock() { return std::chrono::steady_clock::now(); } -int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end) -{ +int64_t elapsedClock(mscclppTime_t start, mscclppTime_t end) { return std::chrono::duration_cast(end - start).count(); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 00000000..4ce78e68 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,21 @@ +function(add_test_executable name sources) + add_executable(${name} ${sources}) + target_link_libraries(${name} mscclpp CUDA::cudart CUDA::cuda_driver) + target_include_directories(${name} PRIVATE ${PROJECT_SOURCE_DIR}/src/include) + if(USE_MPI_FOR_TESTS) + target_link_libraries(${name} MPI::MPI_CXX) + target_compile_definitions(${name} PRIVATE MSCCLPP_USE_MPI_FOR_TESTS) + endif() +endfunction() + +add_test_executable(bootstrap_test_cpp bootstrap_test_cpp.cc) +add_test_executable(communicator_test_cpp communicator_test_cpp.cu) +add_test_executable(allgather_test_cpp allgather_test_cpp.cu) +add_test_executable(allgather_test_host_offloading allgather_test_host_offloading.cu) +add_test_executable(ib_test ib_test.cc) + +# Unit tests +add_executable(unit_tests) +target_link_libraries(unit_tests GTest::gtest_main GTest::gmock_main mscclpp) +add_subdirectory(unit) # This adds the sources to the mscclpp target +gtest_discover_tests(unit_tests DISCOVERY_MODE PRE_TEST) diff --git a/tests/allgather_test.cu b/test/allgather_test.cu similarity index 100% rename from tests/allgather_test.cu rename to test/allgather_test.cu diff --git a/test/allgather_test_cpp.cu b/test/allgather_test_cpp.cu new file mode 100644 index 00000000..60652a0f --- /dev/null +++ b/test/allgather_test_cpp.cu @@ -0,0 +1,526 @@ +#include + +#include + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS +#include "mpi.h" +#endif // MSCCLPP_USE_MPI_FOR_TESTS +#include +#include +#include +#include +#include +#include +#include +#include + +static int nranksPerNode = 8; + +// Propagate errors up + +#define MSCCLPPCHECK(call) \ + do { \ + mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + printf("Failure at %s:%d -> %s\n", __FILE__, __LINE__, mscclppGetErrorString(res)); \ + return res; \ + } \ + } while (0) + +// Check CUDA RT calls +#define CUDACHECK(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + printf("%s:%d Cuda failure '%s'\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// Measure current time in second. +static double getTime(void) +{ + struct timespec tspec; + if (clock_gettime(CLOCK_MONOTONIC, &tspec) == -1) { + printf("clock_gettime failed\n"); + exit(EXIT_FAILURE); + } + return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; +} + +__constant__ mscclpp::channel::SimpleDeviceChannel constDevChans[16]; + +__device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int remoteRank, + size_t nelemsPerGPU) +{ + // this allgather is really simple and implemented as an alltoall + + // this thread's role is a sender role + // put your data asynchronously + if ((threadIdx.x % 32) == 0) + devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + // make sure everyone is put their data before some thread randomly blocks everyone else in signal + __syncthreads(); + // push with flag and sync to make sure the data is received + if ((threadIdx.x % 32) == 0) + devChan.flush(); + + // this thread's role is a receiver role. wait on the semaphore to make sure the data is ready + if ((threadIdx.x % 32) == 0) + devChan.wait(); +} + +__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, + int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) +{ + // this allgather algorithm works as follows: + // Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode + // and waits for data from GPU rank (i-1) % nranksPerNode + // Step 2: GPU rank i sends data to GPU rank (i+2) % nranksPerNode + // ... + // This order is much better for DMA engine for NVLinks + for (int i = 1; i < nranksPerNode; i++) { + if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) { + // put your data to GPU (rank+i) % nranksPerNode and signal in one call + if ((threadIdx.x % 32) == 0) + devChan.putWithSignal(offset, size); + } + // wait for the data from GPU (rank-i) % nranksPerNode to arrive + if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) { + if ((threadIdx.x % 32) == 0) + devChan.wait(); + } + asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory"); + } +} + +__device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, + int remoteRank, size_t nelemsPerGPU) +{ + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), + nelemsPerGPU * sizeof(int)); + if (remoteRank / nranksPerNode == rank / nranksPerNode) + if ((threadIdx.x % 32) == 0) + devChan.flush(); +} + +__device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, + int remoteRank, size_t nelemsPerGPU) +{ + // this allgather is a pipelined and hierarchical one and only works for two nodes + // it is implemented as follows: + // Step 1: each node does a local allgather and concurrently, + // local GPU i exchange (piplineSize-1)/pipelineSize portion of their data with + // its cross-node neighbor (local GPU i on the other node) via IB + // Step 2: each node does a local allgather again with the data just received from its + // cross-node neighbor in step 1, and concurrently, exchange the rest of the data with + // its cross-node neighbor + // Step 3: each node does a local allgather for the last time with the rest of the data + + int pipelineSize = 3; + + // Step 1 + // local allgather + if (remoteRank / nranksPerNode == rank / nranksPerNode) { + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), + nelemsPerGPU * sizeof(int)); + } + // cross-node exchange + if (remoteRank % nranksPerNode == rank % nranksPerNode) { + // opposite side + if ((threadIdx.x % 32) == 0) + devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); + if ((threadIdx.x % 32) == 0) + devChan.wait(); + } + + __syncthreads(); + + // Step 2 + // local allgather + int otherNghr = (rank + nranksPerNode) % world_size; + if (remoteRank / nranksPerNode == rank / nranksPerNode) { + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int), + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); + } + + // cross-node exchange + if (remoteRank % nranksPerNode == rank % nranksPerNode) { + // opposite side + if ((threadIdx.x % 32) == 0) + devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), + nelemsPerGPU / pipelineSize * sizeof(int)); + if ((threadIdx.x % 32) == 0) + devChan.wait(); + } + + __syncthreads(); + + // Step 3 + // local allgather + if (remoteRank / nranksPerNode == rank / nranksPerNode) { + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, + (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), + nelemsPerGPU / pipelineSize * sizeof(int)); + } + + if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) { + if ((threadIdx.x % 32) == 0) + devChan.flush(); + } +} + +__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel) +{ + // find the mapping between remoteRank and devChans + int warpId = threadIdx.x / 32; + int remoteRank = (warpId < rank) ? warpId : warpId + 1; + // Each warp is responsible for one of the remote ranks + mscclpp::channel::SimpleDeviceChannel devChan = constDevChans[warpId]; + + if (kernel == 0) + allgather0(devChan, rank, world_size, remoteRank, nelemsPerGPU); + else if (kernel == 1) + allgather1(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); + else if (kernel == 2) + allgather2(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); +} + +int rankToLocalRank(int rank) +{ + return rank % nranksPerNode; +} + +int rankToNode(int rank) +{ + return rank / nranksPerNode; +} + +void print_usage(const char* prog) +{ +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + printf("usage: %s IP:PORT [rank nranks]\n", prog); +#else + printf("usage: %s IP:PORT rank nranks\n", prog); +#endif +} + +void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSize, size_t nelemsPerGPU, int** data_h, + int** data_d) +{ + CUDACHECK(cudaMalloc(data_d, dataSize)); + CUDACHECK(cudaMemset(*data_d, 0, dataSize)); + + *data_h = new int[nelemsPerGPU * world_size]; + for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { + int val = i + 1; + if (i / nelemsPerGPU == (size_t)rank) { + (*data_h)[i] = val; + } else { + (*data_h)[i] = 0; + } + } + CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice)); +} + +void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, + mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize) +{ + int thisNode = rankToNode(rank); + int cudaNum = rankToLocalRank(rank); + std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); + mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); + std::vector channelIds; + std::vector localMemories; + std::vector> remoteMemories; + + for (int r = 0; r < world_size; ++r) { + if (r == rank) + continue; + mscclpp::Transport transport; + if (rankToNode(r) == thisNode) { + transport = mscclpp::Transport::CudaIpc; + } else { + transport = ibTransport; + } + // Connect with all other ranks + channelIds.push_back(channelService.addChannel(comm.connectOnSetup(r, 0, transport))); + auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); + localMemories.push_back(memory); + comm.sendMemoryOnSetup(memory, r, 0); + remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0)); + } + + comm.setup(); + + std::vector devChannels; + for (size_t i = 0; i < channelIds.size(); ++i) { + devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]), + channelService.addMemory(remoteMemories[i].get()), + channelService.addMemory(localMemories[i]))); + } + + assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel)); + CUDACHECK(cudaMemcpyToSymbol(constDevChans, devChannels.data(), + sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size())); +} + +void printUsage(const char* prog, bool isMpi) +{ + if (isMpi) { + std::string st = "you are using MPI for this test\n"; + st += "two possilbe usages are:\n"; + st += "> " + std::string(prog) + "\n"; + st += "or\n"; + st += "> " + std::string(prog) + " -ip_port [ip:port]\n"; + printf("%s", st.c_str()); + } else { + std::string st = "you are NOT using MPI for this test\n"; + st += "the only possible usage:\n"; + st += "> " + std::string(prog) + " -ip_port [ip:port] -rank [rank] -nranks [nranks]\n"; + printf("%s", st.c_str()); + } +} + +std::unordered_map parseArgs(int argc, const char* argv[], bool isMpi) +{ + std::unordered_map options; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "-rankspernode") { + if (isMpi) { + fprintf(stderr, "Error: -rankspernode should not be specified with MPI.\n"); + exit(-1); + } + if (i + 1 < argc) { + options["rankspernode"] = argv[++i]; + } else { + fprintf(stderr, "Error: -rankspernode option requires an argument.\n"); + ; + exit(-1); + } + } else if (arg == "-kernel") { + if (i + 1 < argc) { + options["kernel"] = argv[++i]; + } else { + fprintf(stderr, "Error: -kernel option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-ip_port") { + if (i + 1 < argc) { + options["ip_port"] = argv[++i]; + } else { + fprintf(stderr, "Error: -ip_port option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-rank") { + if (isMpi) { + fprintf(stderr, "Error: -rank should not be specified with MPI.\n"); + exit(-1); + } + if (i + 1 < argc) { + options["rank"] = argv[++i]; + } else { + fprintf(stderr, "Error: -ip_port option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-nranks") { + if (isMpi) { + fprintf(stderr, "Error: -nranks should not be specified with MPI.\n"); + exit(-1); + } + if (i + 1 < argc) { + options["nranks"] = argv[++i]; + } else { + fprintf(stderr, "Error: -ip_port option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-datasize") { + if (i + 1 < argc) { + options["datasize"] = argv[++i]; + } else { + fprintf(stderr, "Error: -datasize option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-help" || arg == "-h") { + printUsage(argv[0], isMpi); + exit(0); + } else { + fprintf(stderr, "Error: Unknown option %s\n", argv[i]); + exit(-1); + } + } + return options; +} + +int main(int argc, const char* argv[]) +{ + bool isMpi = false; +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + isMpi = true; +#endif + + auto parsedArgs = parseArgs(argc, argv, isMpi); + + int rank; + int world_size; +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + MPI_Init(NULL, NULL); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + // get the local number of nodes with MPI + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmrank; + MPI_Comm_size(shmcomm, &shmrank); + nranksPerNode = shmrank; + MPI_Comm_free(&shmcomm); +#else + if (parsedArgs.find("rank") == parsedArgs.end() || parsedArgs.find("nranks") == parsedArgs.end()) { + printUsage(argv[0], isMpi); + exit(-1); + } + rank = std::stoi(parsedArgs["rank"]); + world_size = std::stoi(parsedArgs["nranks"]); + if (parsedArgs.find("rankspernode") == parsedArgs.end()) { + printUsage(argv[0], isMpi); + exit(-1); + } + nranksPerNode = std::stoi(parsedArgs["rankspernode"]); +#endif + int kernelNum = 0; + if (parsedArgs.find("kernel") != parsedArgs.end()) { + kernelNum = std::stoi(parsedArgs["kernel"]); + } + char* ip_port = NULL; + if (parsedArgs.find("ip_port") == parsedArgs.end()) { + printUsage(argv[0], isMpi); + exit(-1); + } + ip_port = (char*)parsedArgs["ip_port"].c_str(); + + int thisNode = rankToNode(rank); + int cudaNum = rankToLocalRank(rank); + CUDACHECK(cudaSetDevice(cudaNum)); + + int* data_d; + int* data_h; + size_t dataSize = 1024 * 1024 * 1024; + if (parsedArgs.find("datasize") != parsedArgs.end()) { + dataSize = std::stoul(parsedArgs["datasize"]); + } + size_t nelemsPerGPU = dataSize / sizeof(int) / world_size; + + try { + if (rank == 0) + printf("Initializing MSCCL++\n"); + auto bootstrapper = std::make_shared(rank, world_size); + bootstrapper->initialize(ip_port); + mscclpp::Communicator comm(bootstrapper); + mscclpp::channel::DeviceChannelService channelService(comm); + + if (rank == 0) + printf("Initializing data for allgather test\n"); + initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d); + + if (rank == 0) + printf("Setting up the connection in MSCCL++\n"); + setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize); + + if (rank == 0) + printf("Launching MSCCL++ proxy threads\n"); + channelService.startProxy(); + + if (rank == 0) + printf("Testing the correctness of AllGather implementation\n"); + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaDeviceSynchronize()); + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + CUDACHECK(cudaDeviceSynchronize()); + CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { + int val = i + 1; + if (data_h[i] != val) { + printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val); + break; + } + } + int tmp[16]; + // A simple barrier + bootstrapper->allGather(tmp, sizeof(int)); + if (rank == 0) + printf("Successfully checked the correctness\n"); + + // Perf test + int iterwithoutcudagraph = 10; + if (rank == 0) + printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); + CUDACHECK(cudaStreamSynchronize(stream)); + bootstrapper->allGather(tmp, sizeof(int)); + for (int i = 0; i < iterwithoutcudagraph; ++i) { + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + } + CUDACHECK(cudaStreamSynchronize(stream)); + bootstrapper->allGather(tmp, sizeof(int)); + + // cudaGraph Capture + int cudagraphiter = 10; + if (rank == 0) + printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter); + cudaGraph_t graph; + cudaGraphExec_t instance; + cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); + for (int i = 0; i < cudagraphiter; ++i) { + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); + } + cudaStreamEndCapture(stream, &graph); + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + + int cudagraphwarmup = 10; + if (rank == 0) + printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup, + cudagraphiter); + for (int i = 0; i < cudagraphwarmup; ++i) { + cudaGraphLaunch(instance, stream); + } + CUDACHECK(cudaStreamSynchronize(stream)); + + // measure runtime + int cudagraphlaunch = 10; + if (rank == 0) + printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, + cudagraphiter); + bootstrapper->allGather(tmp, sizeof(int)); + double t0, t1, ms, time_in_us; + t0 = getTime(); + for (int i = 0; i < cudagraphlaunch; ++i) { + cudaGraphLaunch(instance, stream); + } + CUDACHECK(cudaStreamSynchronize(stream)); + + t1 = getTime(); + ms = (t1 - t0) * 1000.0; + time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter; + printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, + (double)(dataSize) / 1e9 / (time_in_us / 1e6)); + bootstrapper->allGather(tmp, sizeof(int)); + + if (rank == 0) + printf("Stopping MSCCL++ proxy threads\n"); + channelService.stopProxy(); + + } catch (std::exception& e) { + // todo: throw exceptions in the implementation and process them here + } + printf("Rank %d succeeded!\n", rank); + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + MPI_Finalize(); +#endif + return 0; +} diff --git a/test/allgather_test_host_offloading.cu b/test/allgather_test_host_offloading.cu new file mode 100644 index 00000000..c0ced1f0 --- /dev/null +++ b/test/allgather_test_host_offloading.cu @@ -0,0 +1,401 @@ +#include +#include +#include +#include +#include + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS +#include "mpi.h" +#endif // MSCCLPP_USE_MPI_FOR_TESTS +#include +#include +#include +#include +#include +#include + +int nranksPerNode; +int rank; +int world_size; + +// Propagate errors up + +// Check CUDA RT calls +#define CUCHECK(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + printf("%s:%d Cuda failure '%s'\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// Measure current time in second. +static double getTime(void) +{ + struct timespec tspec; + if (clock_gettime(CLOCK_MONOTONIC, &tspec) == -1) { + printf("clock_gettime failed\n"); + exit(EXIT_FAILURE); + } + return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; +} + + +__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo, mscclpp::DeviceEpoch::DeviceHandle* handles, int handleIndex) +{ + int tid = threadIdx.x; + if (tid != r) + handles[tid].epochIncrement(); + __syncthreads(); + // uint64_t tail; + if (tid == 0){ + mscclpp::ProxyTrigger trigger; + trigger.fst = handleIndex; + fifo.push(trigger); + // tail = fifo.push(trigger); + } + if (tid != r) + handles[tid].wait(); + // if (tid == 0) + // while(*(volatile uint64_t*)fifo.tailReplica < tail) {}; +} + +int rankToLocalRank(int rank) +{ + return rank % nranksPerNode; +} + +int rankToNode(int rank) +{ + return rank / nranksPerNode; +} + +void print_usage(const char* prog) +{ +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + printf("usage: %s IP:PORT [rank nranks]\n", prog); +#else + printf("usage: %s IP:PORT rank nranks\n", prog); +#endif +} + +void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSize, size_t nelemsPerGPU, int** data_h, + int** data_d) +{ + CUCHECK(cudaMalloc(data_d, dataSize)); + CUCHECK(cudaMemset(*data_d, 0, dataSize)); + + *data_h = new int[nelemsPerGPU * world_size]; + for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { + int val = i + 1; + if (i / nelemsPerGPU == (size_t)rank) { + (*data_h)[i] = val; + } else { + (*data_h)[i] = 0; + } + } + CUCHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice)); +} + +class MyProxyService { +private: + int deviceNumaNode_; + mscclpp::Proxy proxy_; + std::vector remoteMemories_; + mscclpp::RegisteredMemory localMemory_; + std::vector> hostEpochs_; + std::vector> deviceEpochs1_; + std::vector> deviceEpochs2_; + std::vector> connections_; + int dataSize_; +public: + MyProxyService(mscclpp::Communicator& comm, int* data_d, int dataSize) : remoteMemories_(world_size), connections_(world_size), dataSize_(dataSize), + proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { + int cudaDevice; + CUCHECK(cudaGetDevice(&cudaDevice)); + getDeviceNumaNode(cudaDevice, &deviceNumaNode_); + + int thisNode = rankToNode(rank); + int cudaNum = rankToLocalRank(rank); + std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); + mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); + std::vector> remoteMemoriesFuture(world_size); + + + localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); + for (int r = 0; r < world_size; ++r) { + if (r == rank){ + hostEpochs_.emplace_back(nullptr); + deviceEpochs1_.emplace_back(nullptr); + deviceEpochs2_.emplace_back(nullptr); + continue; + } + mscclpp::Transport transport; + if (rankToNode(r) == thisNode) { + transport = mscclpp::Transport::CudaIpc; + } else { + transport = ibTransport; + } + // Connect with all other ranks + connections_[r] = comm.connectOnSetup(r, 0, transport); + if (rankToNode(r) == thisNode) { + hostEpochs_.emplace_back(nullptr); + } else { + hostEpochs_.emplace_back(std::make_shared(comm, connections_[r])); + } + deviceEpochs1_.emplace_back(std::make_shared(comm, connections_[r])); + deviceEpochs2_.emplace_back(std::make_shared(comm, connections_[r])); + comm.sendMemoryOnSetup(localMemory_, r, 0); + + remoteMemoriesFuture[r] = comm.recvMemoryOnSetup(r, 0); + } + + comm.setup(); + + for (int r = 0; r < world_size; ++r) { + if (r == rank){ + continue; + } + remoteMemories_[r] = remoteMemoriesFuture[r].get(); + } + } + + void bindThread() { + if (deviceNumaNode_ >= 0) { + numaBind(deviceNumaNode_); + } + } + + mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw) { + static int flusher = 0; + if (triggerRaw.fst > 0) { + int dataSizePerRank = dataSize_ / world_size; + for (int r = 1; r < world_size; ++r) { + int nghr = (rank + r) % world_size; + connections_[nghr]->write(remoteMemories_[nghr], rank*dataSizePerRank, localMemory_, rank*dataSizePerRank, dataSizePerRank); + if (triggerRaw.fst == 1) + deviceEpochs1_[nghr]->signal(); + else + deviceEpochs2_[nghr]->signal(); + if ((flusher % 64) == 0 && mscclpp::AllIBTransports.has(connections_[nghr]->transport())){ + // if we are using IB transport, we need a flush every once in a while + connections_[nghr]->flush(); + } + } + flusher++; + + } + return mscclpp::ProxyHandlerResult::FlushFifoTailAndContinue; + } + + void start(){ + proxy_.start(); + } + + void stop(){ + proxy_.stop(); + } + + mscclpp::HostProxyFifo& fifo(){ + return proxy_.fifo(); + } + + mscclpp::DeviceEpoch::DeviceHandle getDeviceHandle1(int r){ + return deviceEpochs1_[r]->deviceHandle(); + } + + mscclpp::DeviceEpoch::DeviceHandle getDeviceHandle2(int r){ + return deviceEpochs2_[r]->deviceHandle(); + } +}; + +std::unordered_map parseArgs(int argc, char* argv[]) +{ + std::unordered_map options; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "-datasize") { + if (i + 1 < argc) { + options["datasize"] = argv[++i]; + } else { + fprintf(stderr, "Error: -datasize option requires an argument.\n"); + exit(-1); + } + } else if (arg == "-help" || arg == "-h") { + exit(0); + } else { + fprintf(stderr, "Error: Unknown option %s\n", argv[i]); + exit(-1); + } + } + return options; +} + + +int main(int argc, char* argv[]) +{ + // sleep(10); + MPI_Init(&argc, &argv); + auto parsedArgs = parseArgs(argc, argv); + + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + // get the local number of nodes with MPI + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmrank; + MPI_Comm_size(shmcomm, &shmrank); + nranksPerNode = shmrank; + MPI_Comm_free(&shmcomm); + + + int cudaNum = rankToLocalRank(rank); + CUCHECK(cudaSetDevice(cudaNum)); + + if (rank == 0) + printf("Initializing MSCCL++\n"); + auto bootstrap = std::make_shared(rank, world_size); + mscclpp::UniqueId uniqueId; + if (rank == 0) + uniqueId = bootstrap->createUniqueId(); + MPI_Bcast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(uniqueId); + mscclpp::Communicator comm(bootstrap); + + int* data_d; + int* data_h; + size_t dataSize = 1024 * 1024 * 1024; + if (parsedArgs.find("datasize") != parsedArgs.end()) { + dataSize = std::stoul(parsedArgs["datasize"]); + } + size_t nelemsPerGPU = dataSize / sizeof(int) / world_size; + + if (rank == 0) + printf("Initializing data for allgather test\n"); + initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d); + + if (rank == 0) + printf("Setting up the connection in MSCCL++\n"); + + MyProxyService proxyService(comm, data_d, dataSize); + // setupProxyService(comm, proxyService, data_d, dataSize); + + if (rank == 0) + printf("Launching MSCCL++ proxy threads\n"); + proxyService.start(); + mscclpp::DeviceProxyFifo fifo = proxyService.fifo().deviceFifo(); + if (rank == 0) + printf("Testing the correctness of AllGather implementation\n"); + cudaStream_t stream; + CUCHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + mscclpp::DeviceEpoch::DeviceHandle* deviceHandles1; + mscclpp::DeviceEpoch::DeviceHandle* deviceHandles2; + + CUCHECK(cudaMalloc(&deviceHandles1, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * world_size)); + for (int i = 0; i < world_size; ++i) { + if (i == rank) + continue; + auto handle = proxyService.getDeviceHandle1(i); + CUCHECK(cudaMemcpy(&deviceHandles1[i], &handle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), cudaMemcpyHostToDevice)); + } + + CUCHECK(cudaMalloc(&deviceHandles2, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * world_size)); + for (int i = 0; i < world_size; ++i) { + if (i == rank) + continue; + auto handle = proxyService.getDeviceHandle2(i); + CUCHECK(cudaMemcpy(&deviceHandles2[i], &handle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), cudaMemcpyHostToDevice)); + } + + kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1); + CUCHECK(cudaStreamSynchronize(stream)); + + CUCHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < nelemsPerGPU * world_size; i++) { + int val = i + 1; + if (data_h[i] != val) { + printf("oh uh! data_h[%ld] (%d) != val (%d)\n", i, data_h[i], val); + break; + } + } + + bootstrap->barrier(); + if (rank == 0) + printf("Correctness test passed!\n"); + + double t0, t1, ms, time_in_us; + int iterwithoutcudagraph = 10; + if (rank == 0) + printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); + CUCHECK(cudaStreamSynchronize(stream)); + bootstrap->barrier(); + t0 = getTime(); + for (int i = 0; i < iterwithoutcudagraph; ++i) { + kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1); + kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles2, 2); + } + CUCHECK(cudaStreamSynchronize(stream)); + bootstrap->barrier(); + t1 = getTime(); + ms = (t1 - t0) * 1000.0; + time_in_us = ms * 1000. / (float)iterwithoutcudagraph / 2; + printf("No Graph %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, + (double)(dataSize) / 1e9 / (time_in_us / 1e6)); + + // cudaGraph Capture + int cudagraphiter = 10; + if (rank == 0) + printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter); + cudaGraph_t graph; + cudaGraphExec_t instance; + cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); + for (int i = 0; i < cudagraphiter; ++i) { + kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1); + kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles2, 2); + } + cudaStreamEndCapture(stream, &graph); + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + + int cudagraphwarmup = 10; + if (rank == 0) + printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup, + cudagraphiter); + for (int i = 0; i < cudagraphwarmup; ++i) { + cudaGraphLaunch(instance, stream); + } + CUCHECK(cudaStreamSynchronize(stream)); + + // measure runtime + int cudagraphlaunch = 10; + if (rank == 0) + printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, + cudagraphiter); + bootstrap->barrier(); + t0 = getTime(); + for (int i = 0; i < cudagraphlaunch; ++i) { + cudaGraphLaunch(instance, stream); + } + CUCHECK(cudaStreamSynchronize(stream)); + + t1 = getTime(); + ms = (t1 - t0) * 1000.0; + time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter / 2; + if (rank == 0) + printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, + (double)(dataSize) / 1e9 / (time_in_us / 1e6)); + bootstrap->barrier(); + + if (rank == 0) + printf("Stopping MSCCL++ proxy threads\n"); + proxyService.stop(); + + + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + MPI_Finalize(); +#endif + return 0; +} diff --git a/tests/allgather_test_standalone.cu b/test/allgather_test_standalone.cu similarity index 100% rename from tests/allgather_test_standalone.cu rename to test/allgather_test_standalone.cu diff --git a/tests/allreduce_test.cu b/test/allreduce_test.cu similarity index 100% rename from tests/allreduce_test.cu rename to test/allreduce_test.cu diff --git a/tests/bootstrap_test.cc b/test/bootstrap_test.cc similarity index 100% rename from tests/bootstrap_test.cc rename to test/bootstrap_test.cc diff --git a/test/bootstrap_test_cpp.cc b/test/bootstrap_test_cpp.cc new file mode 100644 index 00000000..b32d83fa --- /dev/null +++ b/test/bootstrap_test_cpp.cc @@ -0,0 +1,152 @@ +#include + +#include +#include +#include +#include + +void test_allgather(std::shared_ptr bootstrap) +{ + std::vector tmp(bootstrap->getNranks(), 0); + tmp[bootstrap->getRank()] = bootstrap->getRank() + 1; + bootstrap->allGather(tmp.data(), sizeof(int)); + for (int i = 0; i < bootstrap->getNranks(); i++) { + assert(tmp[i] == i + 1); + } + if (bootstrap->getRank() == 0) + std::cout << "AllGather test passed!" << std::endl; +} + +void test_barrier(std::shared_ptr bootstrap) +{ + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Barrier test passed!" << std::endl; +} + +void test_sendrecv(std::shared_ptr bootstrap) +{ + for (int i = 0; i < bootstrap->getNranks(); i++) { + if (bootstrap->getRank() == i) + continue; + int msg1 = (bootstrap->getRank() + 1) * 3; + int msg2 = (bootstrap->getRank() + 1) * 3 + 1; + int msg3 = (bootstrap->getRank() + 1) * 3 + 2; + bootstrap->send(&msg1, sizeof(int), i, 0); + bootstrap->send(&msg2, sizeof(int), i, 1); + bootstrap->send(&msg3, sizeof(int), i, 2); + } + + for (int i = 0; i < bootstrap->getNranks(); i++) { + if (bootstrap->getRank() == i) + continue; + int msg1 = 0; + int msg2 = 0; + int msg3 = 0; + // recv them in the opposite order to check correctness + bootstrap->recv(&msg2, sizeof(int), i, 1); + bootstrap->recv(&msg3, sizeof(int), i, 2); + bootstrap->recv(&msg1, sizeof(int), i, 0); + assert(msg1 == (i + 1) * 3); + assert(msg2 == (i + 1) * 3 + 1); + assert(msg3 == (i + 1) * 3 + 2); + } + if (bootstrap->getRank() == 0) + std::cout << "Send/Recv test passed!" << std::endl; +} + +void test_all(std::shared_ptr bootstrap) +{ + test_allgather(bootstrap); + test_barrier(bootstrap); + test_sendrecv(bootstrap); +} + +void test_mscclpp_bootstrap_with_id(int rank, int worldSize) +{ + auto bootstrap = std::make_shared(rank, worldSize); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl; +} + +void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar) +{ + std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); + bootstrap->initialize(ipPortPiar); + + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl; +} + +class MPIBootstrap : public mscclpp::BaseBootstrap +{ +public: + MPIBootstrap() : BaseBootstrap() + { + } + int getRank() override + { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + return rank; + } + int getNranks() override + { + int worldSize; + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + return worldSize; + } + void allGather(void* sendbuf, int size) override + { + MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); + } + void barrier() override + { + MPI_Barrier(MPI_COMM_WORLD); + } + void send(void* sendbuf, int size, int dest, int tag) override + { + MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD); + } + void recv(void* recvbuf, int size, int source, int tag) override + { + MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + } +}; + +void test_mpi_bootstrap() +{ + std::shared_ptr bootstrap(new MPIBootstrap()); + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MPI Bootstrap test passed! ---" << std::endl; +} + +int main(int argc, char** argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + if (argc > 2) { + if (rank == 0) + std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl; + MPI_Finalize(); + return 0; + } + test_mscclpp_bootstrap_with_id(rank, worldSize); + if (argc == 2) + test_mscclpp_bootstrap_with_ip_port_pair(rank, worldSize, argv[1]); + test_mpi_bootstrap(); + + MPI_Finalize(); + return 0; +} \ No newline at end of file diff --git a/tests/common.cu b/test/common.cu similarity index 100% rename from tests/common.cu rename to test/common.cu diff --git a/tests/common.h b/test/common.h similarity index 100% rename from tests/common.h rename to test/common.h diff --git a/test/communicator_test_cpp.cu b/test/communicator_test_cpp.cu new file mode 100644 index 00000000..cda4d712 --- /dev/null +++ b/test/communicator_test_cpp.cu @@ -0,0 +1,366 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#define CUDATHROW(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \ + } \ + } while (false) + +mscclpp::Transport findIb(int localRank) +{ + mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, + mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, + mscclpp::Transport::IB6, mscclpp::Transport::IB7}; + return IBs[localRank]; +} + +void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr, + size_t deviceBufferSize, mscclpp::Transport myIbDevice, + mscclpp::RegisteredMemory& localMemory, + std::unordered_map& remoteMemory) +{ + localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); + std::unordered_map> futureRemoteMemory; + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + communicator.sendMemoryOnSetup(localMemory, i, 0); + futureRemoteMemory[i] = communicator.recvMemoryOnSetup(i, 0); + } + } + communicator.setup(); + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + remoteMemory[i] = futureRemoteMemory[i].get(); + } + } +} + +void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, + mscclpp::Transport myIbDevice, + std::unordered_map>& connections) +{ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + if (i / nRanksPerNode == rank / nRanksPerNode) { + connections[i] = communicator.connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); + } else { + connections[i] = communicator.connectOnSetup(i, 0, myIbDevice); + } + } + } + communicator.setup(); +} + +void write_remote(int rank, int worldSize, std::unordered_map>& connections, + std::unordered_map& remoteRegisteredMemories, + mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank) +{ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + auto& conn = connections.at(i); + auto& peerMemory = remoteRegisteredMemories.at(i); + conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, + rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int)); + conn->flush(); + } + } +} + +void device_buffer_init(int rank, int worldSize, int dataCount, std::vector& devicePtr) +{ + for (int n = 0; n < (int)devicePtr.size(); n++) { + std::vector hostBuffer(dataCount, 0); + for (int i = 0; i < dataCount; i++) { + hostBuffer[i] = rank + n * worldSize; + } + CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount * sizeof(int), cudaMemcpyHostToDevice)); + } + CUDATHROW(cudaDeviceSynchronize()); +} + +bool test_device_buffer_write_correctness(int rank, int worldSize, int nRanksPerNode, int dataCount, + std::vector& devicePtr, bool skipLocal = false) +{ + for (int n = 0; n < (int)devicePtr.size(); n++) { + std::vector hostBuffer(dataCount, 0); + CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount * sizeof(int), cudaMemcpyDeviceToHost)); + for (int i = 0; i < worldSize; i++) { + if (i / nRanksPerNode == rank / nRanksPerNode && skipLocal) { + continue; + } + for (int j = i * dataCount / worldSize; j < (i + 1) * dataCount / worldSize; j++) { + if (hostBuffer[j] != i + n * worldSize) { + return false; + } + } + } + } + return true; +} + +void test_write(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize, + std::shared_ptr bootstrap, + std::unordered_map>& connections, + std::vector>& remoteMemory, + std::vector& localMemory, std::vector& devicePtr, int numBuffers) +{ + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t dataCount = deviceBufferSize / sizeof(int); + + device_buffer_init(rank, worldSize, dataCount, devicePtr); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + for (int n = 0; n < numBuffers; n++) { + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); + } + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + // polling until it becomes ready + bool ready = false; + int niter = 0; + do { + ready = test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr); + niter++; + if (niter == 10000) { + throw std::runtime_error("Polling is stuck."); + } + } while (!ready); + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + if (bootstrap->getRank() == 0) + std::cout << "--- Testing vanialla writes passed ---" << std::endl; +} + +__global__ void increament_epochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize) +{ + int tid = threadIdx.x; + if (tid != rank && tid < worldSize) { + deviceEpochs[tid].epochIncrement(); + } +} + +__global__ void wait_epochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize) +{ + int tid = threadIdx.x; + if (tid != rank && tid < worldSize) { + deviceEpochs[tid].wait(); + } +} + +void test_write_with_device_epochs(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize, + mscclpp::Communicator& communicator, + std::shared_ptr bootstrap, + std::unordered_map>& connections, + std::vector>& remoteMemory, + std::vector& localMemory, std::vector& devicePtr, + int numBuffers) +{ + + std::unordered_map> epochs; + for (auto entry : connections) { + auto& conn = entry.second; + epochs.insert({entry.first, std::make_shared(communicator, conn)}); + } + communicator.setup(); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Epochs are created" << std::endl; + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t dataCount = deviceBufferSize / sizeof(int); + + device_buffer_init(rank, worldSize, dataCount, devicePtr); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + mscclpp::DeviceEpoch::DeviceHandle* deviceEpochHandles; + CUDATHROW(cudaMalloc(&deviceEpochHandles, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * worldSize)); + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + mscclpp::DeviceEpoch::DeviceHandle deviceHandle = epochs[i]->deviceHandle(); + CUDATHROW(cudaMemcpy(&deviceEpochHandles[i], &deviceHandle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), + cudaMemcpyHostToDevice)); + } + } + CUDATHROW(cudaDeviceSynchronize()); + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA device epochs are created" << std::endl; + + for (int n = 0; n < numBuffers; n++) { + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); + } + + increament_epochs<<<1, worldSize>>>(deviceEpochHandles, rank, worldSize); + CUDATHROW(cudaDeviceSynchronize()); + + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + epochs[i]->signal(); + } + } + + wait_epochs<<<1, worldSize>>>(deviceEpochHandles, rank, worldSize); + CUDATHROW(cudaDeviceSynchronize()); + + if (!test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr)) { + throw std::runtime_error("unexpected result."); + } + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "--- Testing writes with device epochs for " << std::to_string(numBuffers) << " buffers passed ---" + << std::endl; +} + +void test_write_with_host_epochs(int rank, int worldSize, int nRanksPerNode, int deviceBufferSize, + mscclpp::Communicator& communicator, std::shared_ptr bootstrap, + std::unordered_map>& connections, + std::vector>& remoteMemory, + std::vector& localMemory, std::vector& devicePtr, + int numBuffers) +{ + + std::unordered_map> epochs; + for (auto entry : connections) { + auto& conn = entry.second; + if (conn->transport() == mscclpp::Transport::CudaIpc) + continue; + epochs.insert({entry.first, std::make_shared(communicator, conn)}); + } + communicator.setup(); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Epochs are created" << std::endl; + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t dataCount = deviceBufferSize / sizeof(int); + + device_buffer_init(rank, worldSize, dataCount, devicePtr); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Host epochs are created" << std::endl; + + for (int n = 0; n < numBuffers; n++) { + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); + } + + for (int i = 0; i < worldSize; i++) { + if (i != rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) { + epochs[i]->increamentAndSignal(); + } + } + + for (int i = 0; i < worldSize; i++) { + if (i != rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) { + epochs[i]->wait(); + } + } + + if (!test_device_buffer_write_correctness(rank, worldSize, nRanksPerNode, dataCount, devicePtr, true)) { + throw std::runtime_error("unexpected result."); + } + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "--- Testing writes with host epochs for " << std::to_string(numBuffers) << " buffers passed ---" + << std::endl; +} + +void test_communicator(int rank, int worldSize, int nRanksPerNode) +{ + auto bootstrap = std::make_shared(rank, worldSize); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + mscclpp::Communicator communicator(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "Communicator initialization passed" << std::endl; + + std::unordered_map> connections; + auto myIbDevice = findIb(rank % nRanksPerNode); + + make_connections(communicator, rank, worldSize, nRanksPerNode, myIbDevice, connections); + if (bootstrap->getRank() == 0) + std::cout << "Connection setup passed" << std::endl; + + int numBuffers = 10; + std::vector devicePtr(numBuffers); + int deviceBufferSize = 1024 * 1024; + + std::vector localMemory(numBuffers); + std::vector> remoteMemory(numBuffers); + + for (int n = 0; n < numBuffers; n++) { + if (n % 100 == 0) + std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl; + CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize)); + register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], + remoteMemory[n]); + } + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + test_write(rank, worldSize, nRanksPerNode, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, + devicePtr, numBuffers); + + test_write_with_device_epochs(rank, worldSize, nRanksPerNode, deviceBufferSize, communicator, bootstrap, connections, + remoteMemory, localMemory, devicePtr, numBuffers); + + test_write_with_host_epochs(rank, worldSize, nRanksPerNode, deviceBufferSize, communicator, bootstrap, connections, + remoteMemory, localMemory, devicePtr, numBuffers); + + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; + + for (int n = 0; n < numBuffers; n++) { + CUDATHROW(cudaFree(devicePtr[n])); + } +} + +int main(int argc, char** argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmWorldSize; + MPI_Comm_size(shmcomm, &shmWorldSize); + int nRanksPerNode = shmWorldSize; + MPI_Comm_free(&shmcomm); + + test_communicator(rank, worldSize, nRanksPerNode); + + MPI_Finalize(); + return 0; +} diff --git a/tests/unittests/ib_test.cc b/test/ib_test.cc similarity index 53% rename from tests/unittests/ib_test.cc rename to test/ib_test.cc index 2c194eaf..753d6fa4 100644 --- a/tests/unittests/ib_test.cc +++ b/test/ib_test.cc @@ -1,7 +1,9 @@ #include "alloc.h" #include "checks.h" -#include "ib.h" -#include +#include "ib.hpp" +#include "infiniband/verbs.h" +#include +#include #include // Measure current time in second. @@ -24,8 +26,8 @@ int main(int argc, const char* argv[]) printf("Usage: %s <0(recv)/1(send)> \n", argv[0]); return 1; } - const char* ip_port = argv[1]; - int is_send = atoi(argv[2]); + const char* ipPortPair = argv[1]; + int isSend = atoi(argv[2]); int cudaDevId = atoi(argv[3]); std::string ibDevName = "mlx5_ib" + std::string(argv[4]); @@ -35,51 +37,40 @@ int main(int argc, const char* argv[]) int nelem = 1; MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem)); - mscclppComm_t comm; - MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send)); + std::shared_ptr bootstrap(new mscclpp::Bootstrap(isSend, 2)); + bootstrap->initialize(ipPortPair); - struct mscclppIbContext* ctx; - struct mscclppIbQp* qp; - struct mscclppIbMr* mr; - MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str())); - MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr)); + mscclpp::IbCtx ctx(ibDevName); + mscclpp::IbQp* qp = ctx.createQp(); + const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem); - struct mscclppIbQpInfo* qpInfo; - MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2)); - qpInfo[is_send] = qp->info; + std::array qpInfo; + qpInfo[isSend] = qp->getInfo(); - struct mscclppIbMrInfo* mrInfo; - MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2)); - mrInfo[is_send] = mr->info; + std::array mrInfo; + mrInfo[isSend] = mr->getInfo(); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo))); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo))); + bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); + bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo)); - for (int i = 0; i < 2; ++i) { - if (i == is_send) + for (int i = 0; i < bootstrap->getNranks(); ++i) { + if (i == isSend) continue; - qp->rtr(&qpInfo[i]); + qp->rtr(qpInfo[i]); qp->rts(); break; } printf("connection succeed\n"); - // A simple barrier - int* tmp; - MSCCLPPCHECK(mscclppCalloc(&tmp, 2)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); + bootstrap->barrier(); - if (is_send) { + if (isSend) { int maxIter = 100000; double start = getTime(); for (int iter = 0; iter < maxIter; ++iter) { - qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); - if (qp->postSend() != 0) { - WARN("postSend failed"); - return 1; - } + qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); + qp->postSend(); bool waiting = true; while (waiting) { int wcNum = qp->pollCq(); @@ -88,7 +79,7 @@ int main(int argc, const char* argv[]) return 1; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &qp->wcs[i]; + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); return 1; @@ -103,10 +94,7 @@ int main(int argc, const char* argv[]) } // A simple barrier - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - - MSCCLPPCHECK(mscclppIbContextDestroy(ctx)); - MSCCLPPCHECK(mscclppCommDestroy(comm)); + bootstrap->barrier(); return 0; } diff --git a/tests/p2p_test.cu b/test/p2p_test.cu similarity index 98% rename from tests/p2p_test.cu rename to test/p2p_test.cu index 95f18e6c..e2218e83 100644 --- a/tests/p2p_test.cu +++ b/test/p2p_test.cu @@ -54,7 +54,7 @@ __global__ void kernel(int rank, int world_size) volatile int* data = (volatile int*)devConn.localBuff; volatile uint64_t* localFlag = devConn.localFlag; #if (USE_DMA_FOR_P2P == 0) - volatile uint64_t* remoteFlag = devConn.remoteFlag; + volatile uint64_t* remoteSignalEpochId = devConn.remoteSignalEpochId; #endif volatile uint64_t* proxyFlag = devConn.proxyFlag; @@ -106,7 +106,7 @@ __global__ void kernel(int rank, int world_size) volatile int* remoteData = (volatile int*)devConn.remoteBuff; // Wait until the remote data is set - while (*remoteFlag == baseFlag) { + while (*remoteSignalEpochId == baseFlag) { } // Read remote data diff --git a/tests/sendrecv_test.cu b/test/sendrecv_test.cu similarity index 100% rename from tests/sendrecv_test.cu rename to test/sendrecv_test.cu diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt new file mode 100644 index 00000000..a5b2f4b9 --- /dev/null +++ b/test/unit/CMakeLists.txt @@ -0,0 +1,3 @@ +target_sources(unit_tests PRIVATE + core_tests.cc +) diff --git a/test/unit/core_tests.cc b/test/unit/core_tests.cc new file mode 100644 index 00000000..e3bf7265 --- /dev/null +++ b/test/unit/core_tests.cc @@ -0,0 +1,49 @@ +#include +#include +#include + +class LocalCommunicatorTest : public ::testing::Test { + protected: + void SetUp() override { + bootstrap = std::make_shared(0, 1); + comm = std::make_shared(bootstrap); + } + + std::shared_ptr bootstrap; + std::shared_ptr comm; +}; + +class MockSetuppable : public mscclpp::Setuppable { + public: + MOCK_METHOD(void, beginSetup, (std::shared_ptr bootstrap), (override)); + MOCK_METHOD(void, endSetup, (std::shared_ptr bootstrap), (override)); +}; + +TEST_F(LocalCommunicatorTest, OnSetup) { + auto mockSetuppable = std::make_shared(); + comm->onSetup(mockSetuppable); + EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast(bootstrap))); + EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast(bootstrap))); + comm->setup(); +} + +TEST_F(LocalCommunicatorTest, RegisterMemory) { + int dummy[42]; + auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports); + EXPECT_EQ(memory.data(), &dummy); + EXPECT_EQ(memory.size(), sizeof(dummy)); + EXPECT_EQ(memory.rank(), 0); + EXPECT_EQ(memory.transports(), mscclpp::NoTransports); +} + +TEST_F(LocalCommunicatorTest, SendMemoryToSelf) { + int dummy[42]; + auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports); + comm->sendMemoryOnSetup(memory, 0, 0); + auto memoryFuture = comm->recvMemoryOnSetup(0, 0); + comm->setup(); + auto sameMemory = memoryFuture.get(); + EXPECT_EQ(sameMemory.size(), memory.size()); + EXPECT_EQ(sameMemory.rank(), memory.rank()); + EXPECT_EQ(sameMemory.transports(), memory.transports()); +}