diff --git a/python/benchmark/allreduce.cu b/python/benchmark/allreduce.cu
index 4dc2b0c4..e90dc147 100644
--- a/python/benchmark/allreduce.cu
+++ b/python/benchmark/allreduce.cu
@@ -118,12 +118,9 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
 // AllReduce1
 // -------------------------------------------
 
-#ifndef READ_ONLY
-#define READ_ONLY 0
-#endif
-
-extern "C" __global__ void __launch_bounds__(1024, 1)
-    allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks, size_t nelems) {
+template <int READ_ONLY>
+__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
+                                  size_t nelems) {
   const size_t chunkSize = nelems / nranks;
   if (nranks == 1) return;
   const int nPeer = nranks - 1;
@@ -211,13 +208,23 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
   }
 }
 
+// Kernel entry point for AllReduce1. `read_only` is a runtime argument that selects which
+// READ_ONLY instantiation of allreduce1_helper runs (non-zero -> <1>, zero -> <0>).
+extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
+                                                                 int rank, int nranks, size_t nelems, int read_only) {
+  if (read_only)
+    allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
+  else
+    allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
+}
+
 // -------------------------------------------
 // AllReduce2
 // -------------------------------------------
 
 __device__ uint64_t globalFlag = 1;
 
-extern "C" __global__ void __launch_bounds__(512, 1)
+extern "C" __global__ void __launch_bounds__(1024, 1)
     allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
                int worldSize, size_t nelems) {
   nelems = nelems / (sizeof(int) / sizeof(TYPE));
diff --git a/python/benchmark/allreduce_bench.py b/python/benchmark/allreduce_bench.py
index aa2c096e..2cf09cba 100644
--- a/python/benchmark/allreduce_bench.py
+++ b/python/benchmark/allreduce_bench.py
@@ -23,6 +23,54 @@ else:
     raise RuntimeError("Unknown data type")
 
 
+def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
+    """Plot MSCCLPP vs. NCCL bandwidth and speed-up, then save the figure as a PDF.
+
+    Args:
+        sizes: message sizes; used as x coordinates on a log2 axis and as tick labels.
+        mscclpp_algbw: MSCCLPP algorithm bandwidth per size (axis labelled GB/s).
+        nccl_algbw: NCCL algorithm bandwidth per size (axis labelled GB/s).
+        speed_ups: NCCL time / MSCCLPP time per size, drawn on a secondary y-axis.
+    """
+    import matplotlib.pyplot as plt
+
+    human_readable_sizes = [human_readable_size(size) for size in sizes]
+
+    fig, ax1 = plt.subplots(figsize=(10, 6))
+
+    # Plotting AlgBW for MSCCLPP and NCCL on the primary y-axis
+    (line1,) = ax1.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
+    (line2,) = ax1.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
+    ax1.set_ylabel("AlgBW (GB/s)")
+    ax1.set_xlabel("Data Size")
+
+    # Logarithmic x-axis
+    ax1.set_xscale("log", base=2)
+    ax1.set_xticks(sizes)
+    ax1.set_xticklabels(human_readable_sizes, rotation=45)
+
+    # Adding secondary y-axis for Speed Up
+    ax2 = ax1.twinx()
+    (line3,) = ax2.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
+    ax2.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
+    ax2.tick_params(axis="y", labelcolor="green")
+
+    # Set the lower bound of the secondary y-axis to 0
+    ax2.set_ylim(bottom=0)
+
+    # Creating legends
+    lines = [line1, line2, line3]
+    labels = [line.get_label() for line in lines]
+    ax1.legend(lines, labels, loc="upper left")
+
+    # Setting title and grid
+    ax1.set_title("MSCCLPP vs NCCL -- " + str(MPI.COMM_WORLD.size // N_GPUS_PER_NODE) + " Nodes")
+    ax2.grid(True, which="both", ls="--")
+
+    # Saving the plot
+    plt.savefig("mscclpp_vs_nccl_comparison.pdf", format="pdf")
+
+
 def human_readable_size(size, decimal_places=1):
     for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
         if size < 1024.0 or unit == "PiB":
@@ -99,15 +147,12 @@ def run_benchmark(
     memory_out = cp.zeros(nelem, dtype=data_type)
     cp.cuda.runtime.deviceSynchronize()
 
+    proxy_service = None
     if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
         if memory.nbytes < 2**20:
             mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
         elif memory.nbytes < 2**29:
-            if memory.nbytes >= 2**20 and memory.nbytes <= 2**22:
-                read_only = 0
-            else:
-                read_only = 1
-            mscclpp_call = MscclppAllReduce1(mscclpp_group, memory, read_only=read_only)
+            mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
         else:
             proxy_service = ProxyService()
             mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
@@ -117,14 +162,13 @@ def run_benchmark(
             proxy_service = ProxyService()
             mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
             proxy_service.start_proxy()
-            best_config = find_best_config(mscclpp_call, 100)
-            mscclpp_call.set_params(*best_config)
         else:
             proxy_service = ProxyService()
             mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
             proxy_service.start_proxy()
-            best_config = find_best_config(mscclpp_call, 20)
-            mscclpp_call.set_params(*best_config)
+
+    best_config = find_best_config(mscclpp_call, 20)
+    mscclpp_call.set_params(*best_config)
 
     nccl_call = NcclAllReduce(nccl_op, memory)
 
@@ -145,6 +189,7 @@ def run_benchmark(
     MPI.COMM_WORLD.barrier()
     proxy_service.stop_proxy()
 
+    speed_up = nccl_time / mscclpp_time
     if MPI.COMM_WORLD.rank == 0:
         table.add_row(
             [
@@ -155,12 +200,14 @@ def run_benchmark(
                 "{:.2f}".format(nccl_time),
                 "{:.2f}".format(nccl_algBw),
                 nccl_check,
-                "{:.2f}".format(nccl_time / mscclpp_time),
+                "{:.2f}".format(speed_up),
             ]
         )
     if MPI.COMM_WORLD.rank == 0:
         print(".", end="", flush=True)
 
+    return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
+
 
 if __name__ == "__main__":
     shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
@@ -200,16 +247,29 @@ if __name__ == "__main__":
         "Speed Up",
     ]
 
-    for i in range(10, 28):
+    sizes = []
+    mscclpp_algbw = []
+    nccl_algbw = []
+    speed_ups = []
+    for i in range(10, 30):
         if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
-            run_benchmark(mscclpp_group, nccl_comm, table, 100, 2**i)
+            nelems = 2**i
         elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
-            run_benchmark(mscclpp_group, nccl_comm, table, 100, 3 * 2**i)
+            nelems = 3 * 2**i
         else:
             raise RuntimeError("Only support one node/two nodes communication")
 
+        size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
+        sizes.append(size)
+        mscclpp_algbw.append(mscclpp_algBw)
+        nccl_algbw.append(nccl_algBw)
+        speed_ups.append(speed_up)
+
     if MPI.COMM_WORLD.rank == 0:
         print()
         print(table)
+
+        plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)
+
     mscclpp_group = None
     nccl_comm = None
diff --git a/python/benchmark/mscclpp_op.py b/python/benchmark/mscclpp_op.py
index 92d17e2f..ab51f7c8 100644
--- a/python/benchmark/mscclpp_op.py
+++ b/python/benchmark/mscclpp_op.py
@@ -35,7 +35,7 @@ class MscclppAllReduce1:
         group: mscclpp_comm.CommGroup,
         memory: cp.ndarray,
         read_only: int = 1,
-        nthreads: int = 1024,
+        block_size: int = 1024,
         nblocks: int = 24,
     ):
         self.group = group
@@ -55,30 +55,56 @@ class MscclppAllReduce1:
             file="allreduce.cu",
             kernel_name="allreduce1",
             file_dir=file_dir,
-            macro_dict={"TYPE": type_str, "READ_ONLY": str(read_only)},
+            macro_dict={"TYPE": type_str},
         ).get_compiled_kernel()
-        self.params = b""
         self.device_handles = []
         for rank in range(self.group.nranks):
             if rank != self.group.my_rank:
                 self.device_handles.append(self.sm_channels[rank].device_handle().raw)
+
+        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
+
+        self.set_params(nblocks, block_size, read_only)
+
+    def __call__(self, stream_ptr):
+        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
+        return self.memory
+
+    def set_params(self, nblocks, block_size, read_only):
+        self.nblocks = nblocks
+        self.block_size = block_size
+        self.read_only = read_only
+        self.params = b""
         self.params += pack(
-            cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
+            self.device_handles_cp,
             self.memory,
             self.group.my_rank,
             self.group.nranks,
             ctypes.c_size_t(self.memory.size),
+            self.read_only,
         )
-        self.nthreads = nthreads
-        self.nblocks = nblocks
 
-    def __call__(self, stream_ptr):
-        self.kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, stream_ptr)
-        return self.memory
+    def auto_tune(self):
+        """Apply each candidate (nblocks, block_size, read_only) via set_params, then yield it."""
+        nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
+        block_size_to_try = [256, 512, 1024]
+        read_only_to_try = [0, 1]
+        for nblocks in nblocks_to_try:
+            for block_size in block_size_to_try:
+                for read_only in read_only_to_try:
+                    self.set_params(nblocks, block_size, read_only)
+                    yield nblocks, block_size, read_only
 
 
 class MscclppAllReduce2:
-    def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, memory_out: cp.ndarray):
+    def __init__(
+        self,
+        group: mscclpp_comm.CommGroup,
+        memory: cp.ndarray,
+        memory_out: cp.ndarray,
+        block_size: int = 512,
+        nblocks: int = 21,
+    ):
         self.group = group
         self.memory = memory
         self.memory_out = memory_out
@@ -97,13 +123,26 @@ class MscclppAllReduce2:
         self.kernel = KernelBuilder(
             file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
         ).get_compiled_kernel()
-        self.params = b""
         self.device_handles = []
         for rank in range(self.group.nranks):
             if rank != self.group.my_rank:
                 self.device_handles.append(self.sm_channels[rank].device_handle().raw)
+
+        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
+
+        self.set_params(nblocks, block_size)
+
+    def __call__(self, stream_ptr):
+        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
+        return self.memory_out
+
+    def set_params(self, nblocks, block_size):
+        self.nblocks = nblocks
+        self.block_size = block_size
+
+        self.params = b""
         self.params += pack(
-            cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
+            self.device_handles_cp,
             self.memory,
             self.scratch,
             self.memory_out,
@@ -112,13 +151,25 @@ class MscclppAllReduce2:
             ctypes.c_size_t(self.memory.size),
         )
 
-    def __call__(self, stream_ptr):
-        self.kernel.launch_kernel(self.params, 21, 512, 0, stream_ptr)
-        return self.memory_out
+    def auto_tune(self):
+        """Apply each candidate (nblocks, block_size) via set_params, then yield it."""
+        nblocks_to_try = [21, 42, 63, 84, 105]
+        block_size_to_try = [256, 512, 1024]
+        for nblocks in nblocks_to_try:
+            for block_size in block_size_to_try:
+                self.set_params(nblocks, block_size)
+                yield nblocks, block_size
 
 
 class MscclppAllReduce3:
-    def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, proxy_service: ProxyService):
+    def __init__(
+        self,
+        group: mscclpp_comm.CommGroup,
+        memory: cp.ndarray,
+        proxy_service: ProxyService,
+        block_size: int = 1024,
+        nblocks: int = 24,
+    ):
         self.group = group
         self.memory = memory
         remote_nghrs = list(range(self.group.nranks))
@@ -141,16 +192,28 @@ class MscclppAllReduce3:
         self.kernel = KernelBuilder(
             file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
         ).get_compiled_kernel()
-        self.params = b""
         self.fst_device_handles = []
         self.snd_device_handles = []
         for rank in range(self.group.nranks):
             if rank != self.group.my_rank:
                 self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
                 self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
+        self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
+        self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)
+
+        self.set_params(nblocks, block_size)
+
+    def __call__(self, stream_ptr):
+        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
+        return self.memory
+
+    def set_params(self, nblocks, block_size):
+        self.nblocks = nblocks
+        self.block_size = block_size
+        self.params = b""
         self.params += pack(
-            cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8),
-            cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8),
+            self.fst_device_handles_cp,
+            self.snd_device_handles_cp,
             self.memory,
             self.scratch,
             self.group.my_rank,
@@ -158,9 +221,14 @@ class MscclppAllReduce3:
             ctypes.c_size_t(self.memory.size),
         )
 
-    def __call__(self, stream_ptr):
-        self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
-        return self.memory
+    def auto_tune(self):
+        """Apply each candidate (nblocks, block_size) via set_params, then yield it."""
+        nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
+        block_size_to_try = [256, 512, 1024]
+        for nblocks in nblocks_to_try:
+            for block_size in block_size_to_try:
+                self.set_params(nblocks, block_size)
+                yield nblocks, block_size
 
 
 class MscclppAllReduce4:
@@ -220,6 +288,14 @@ class MscclppAllReduce4:
                 )
                 self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)
 
+        self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
+        self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
+            memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
+        )
+        self.all_gather_proxy_device_handles_cp = cp.asarray(
+            memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8
+        )
+
         self.set_params(nblocks, block_size, pipeline_depth)
 
     def __call__(self, stream_ptr):
@@ -233,9 +309,9 @@ class MscclppAllReduce4:
 
         self.params = b""
         self.params += pack(
-            cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
-            cp.asarray(memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8),
-            cp.asarray(memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8),
+            self.sm_device_handles_cp,
+            self.reduce_sactter_proxy_device_handles_cp,
+            self.all_gather_proxy_device_handles_cp,
             self.memory,
             self.scratch,
             self.group.my_rank,
@@ -310,6 +386,9 @@ class MscclppAllReduce5:
             if rank != self.group.my_rank and not in_same_node(rank):
                 self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)
 
+        self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
+        self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)
+
         self.set_params(nblocks, block_size)
 
     def __call__(self, stream_ptr):
@@ -322,8 +401,8 @@ class MscclppAllReduce5:
 
         self.params = b""
         self.params += pack(
-            cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
-            cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8),
+            self.sm_device_handles_cp,
+            self.proxy_device_handles_cp,
             self.memory,
             self.scratch,
             self.put_buff,
diff --git a/python/requirements_cu11.txt b/python/requirements_cu11.txt
index 47285da3..7f4b4ea1 100644
--- a/python/requirements_cu11.txt
+++ b/python/requirements_cu11.txt
@@ -5,3 +5,4 @@ cuda-python
 netifaces
 pytest
 numpy
+matplotlib
diff --git a/python/requirements_cu12.txt b/python/requirements_cu12.txt
index 094dff8d..aa657eac 100644
--- a/python/requirements_cu12.txt
+++ b/python/requirements_cu12.txt
@@ -5,3 +5,4 @@ cuda-python
 netifaces
 pytest
 numpy
+matplotlib