Auto-tune single-node AllReduce (#219)

Add a single-node auto-tuner and a graph plotter, and fix a bug that caused an illegal memory access.

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Saeed Maleki
2023-11-17 05:42:05 -08:00
committed by GitHub
parent 060fda12e6
commit 1d1199703a
5 changed files with 182 additions and 47 deletions

View File

@@ -118,12 +118,9 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
// AllReduce1
// -------------------------------------------
#ifndef READ_ONLY
#define READ_ONLY 0
#endif
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks, size_t nelems) {
template <int READ_ONLY>
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
size_t nelems) {
const size_t chunkSize = nelems / nranks;
if (nranks == 1) return;
const int nPeer = nranks - 1;
@@ -211,13 +208,21 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
}
}
// Kernel entry point for AllReduce1.
//
// `allreduce1_helper` takes READ_ONLY as a compile-time template argument, so
// the runtime `read_only` flag is mapped here onto one of its two
// instantiations: any non-zero value selects <1>, zero selects <0>. All other
// arguments are forwarded unchanged. Declared with __launch_bounds__(1024, 1).
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
                                                                 int rank, int nranks, size_t nelems, int read_only) {
  if (read_only == 0) {
    allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
    return;
  }
  allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
}
// -------------------------------------------
// AllReduce2
// -------------------------------------------
__device__ uint64_t globalFlag = 1;
extern "C" __global__ void __launch_bounds__(512, 1)
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
int worldSize, size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));

View File

@@ -23,6 +23,46 @@ else:
raise RuntimeError("Unknown data type")
def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups, output_file="mscclpp_vs_nccl_comparison.pdf"):
    """Plot MSCCLPP vs NCCL algorithm bandwidth and speed-up, and save it as a PDF.

    Args:
        sizes: Data sizes used as x values; each one is also rendered as a tick
            label through ``human_readable_size``.
        mscclpp_algbw: MSCCLPP AlgBW values, one per entry in ``sizes``.
        nccl_algbw: NCCL AlgBW values, one per entry in ``sizes``.
        speed_ups: Speed-up values (labelled "NCCL Time / MSCCLPP Time"), one
            per entry in ``sizes``; drawn on a secondary y-axis.
        output_file: Destination of the PDF. Defaults to the file name that was
            previously hard-coded, so existing callers are unaffected.
    """
    # Imported locally so matplotlib is only required when a plot is requested.
    import matplotlib.pyplot as plt

    human_readable_sizes = [human_readable_size(size) for size in sizes]

    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        # Plotting AlgBW for MSCCLPP and NCCL on the primary y-axis
        (line1,) = ax1.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
        (line2,) = ax1.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
        ax1.set_ylabel("AlgBW (GB/s)")
        ax1.set_xlabel("Data Size")

        # Logarithmic x-axis
        ax1.set_xscale("log", base=2)
        ax1.set_xticks(sizes)
        ax1.set_xticklabels(human_readable_sizes, rotation=45)

        # Adding secondary y-axis for Speed Up
        ax2 = ax1.twinx()
        (line3,) = ax2.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
        ax2.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
        ax2.tick_params(axis="y", labelcolor="green")

        # Set the lower bound of the secondary y-axis to 0
        ax2.set_ylim(bottom=0)

        # Creating legends: one combined legend for the lines of both axes
        lines = [line1, line2, line3]
        labels = [line.get_label() for line in lines]
        ax1.legend(lines, labels, loc="upper left")

        # Setting title and grid
        ax1.set_title("MSCCLPP vs NCCL -- " + str(MPI.COMM_WORLD.size // N_GPUS_PER_NODE) + " Nodes")
        ax2.grid(True, which="both", ls="--")

        # Saving the plot
        fig.savefig(output_file, format="pdf")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # release it even if plotting or saving raised.
        plt.close(fig)
def human_readable_size(size, decimal_places=1):
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
if size < 1024.0 or unit == "PiB":
@@ -99,15 +139,12 @@ def run_benchmark(
memory_out = cp.zeros(nelem, dtype=data_type)
cp.cuda.runtime.deviceSynchronize()
proxy_service = None
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
if memory.nbytes < 2**20:
mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
elif memory.nbytes < 2**29:
if memory.nbytes >= 2**20 and memory.nbytes <= 2**22:
read_only = 0
else:
read_only = 1
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory, read_only=read_only)
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
@@ -117,14 +154,13 @@ def run_benchmark(
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
best_config = find_best_config(mscclpp_call, 100)
mscclpp_call.set_params(*best_config)
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
best_config = find_best_config(mscclpp_call, 20)
mscclpp_call.set_params(*best_config)
best_config = find_best_config(mscclpp_call, 20)
mscclpp_call.set_params(*best_config)
nccl_call = NcclAllReduce(nccl_op, memory)
@@ -145,6 +181,7 @@ def run_benchmark(
MPI.COMM_WORLD.barrier()
proxy_service.stop_proxy()
speed_up = nccl_time / mscclpp_time
if MPI.COMM_WORLD.rank == 0:
table.add_row(
[
@@ -155,12 +192,14 @@ def run_benchmark(
"{:.2f}".format(nccl_time),
"{:.2f}".format(nccl_algBw),
nccl_check,
"{:.2f}".format(nccl_time / mscclpp_time),
"{:.2f}".format(speed_up),
]
)
if MPI.COMM_WORLD.rank == 0:
print(".", end="", flush=True)
return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
if __name__ == "__main__":
shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
@@ -200,16 +239,29 @@ if __name__ == "__main__":
"Speed Up",
]
for i in range(10, 28):
sizes = []
mscclpp_algbw = []
nccl_algbw = []
speed_ups = []
for i in range(10, 30):
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
run_benchmark(mscclpp_group, nccl_comm, table, 100, 2**i)
nelems = 2**i
elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
run_benchmark(mscclpp_group, nccl_comm, table, 100, 3 * 2**i)
nelems = 3 * 2**i
else:
raise RuntimeError("Only support one node/two nodes communication")
size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
sizes.append(size)
mscclpp_algbw.append(mscclpp_algBw)
nccl_algbw.append(nccl_algBw)
speed_ups.append(speed_up)
if MPI.COMM_WORLD.rank == 0:
print()
print(table)
plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)
mscclpp_group = None
nccl_comm = None

View File

@@ -35,7 +35,7 @@ class MscclppAllReduce1:
group: mscclpp_comm.CommGroup,
memory: cp.ndarray,
read_only: int = 1,
nthreads: int = 1024,
block_size: int = 1024,
nblocks: int = 24,
):
self.group = group
@@ -55,30 +55,55 @@ class MscclppAllReduce1:
file="allreduce.cu",
kernel_name="allreduce1",
file_dir=file_dir,
macro_dict={"TYPE": type_str, "READ_ONLY": str(read_only)},
macro_dict={"TYPE": type_str},
).get_compiled_kernel()
self.params = b""
self.device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
self.set_params(nblocks, block_size, read_only)
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
return self.memory
def set_params(self, nblocks, block_size, read_only):
self.nblocks = nblocks
self.block_size = block_size
self.read_only = read_only
self.params = b""
self.params += pack(
cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
self.device_handles_cp,
self.memory,
self.group.my_rank,
self.group.nranks,
ctypes.c_size_t(self.memory.size),
self.read_only,
)
self.nthreads = nthreads
self.nblocks = nblocks
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, stream_ptr)
return self.memory
def auto_tune(self):
    """Generate every candidate ``(nblocks, block_size, read_only)`` configuration.

    Each candidate is applied through ``set_params`` immediately before it is
    yielded, so the object is already configured with it when the consumer
    receives it. ``nblocks`` varies slowest and ``read_only`` fastest. After the
    generator is exhausted, the last candidate remains applied.
    """
    grid_sizes = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
    thread_counts = [256, 512, 1024]
    read_only_flags = [0, 1]
    candidates = (
        (grid, threads, flag) for grid in grid_sizes for threads in thread_counts for flag in read_only_flags
    )
    for candidate in candidates:
        self.set_params(*candidate)
        yield candidate
class MscclppAllReduce2:
def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, memory_out: cp.ndarray):
def __init__(
self,
group: mscclpp_comm.CommGroup,
memory: cp.ndarray,
memory_out: cp.ndarray,
block_size: int = 512,
nblocks: int = 21,
):
self.group = group
self.memory = memory
self.memory_out = memory_out
@@ -97,13 +122,26 @@ class MscclppAllReduce2:
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
).get_compiled_kernel()
self.params = b""
self.device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
self.set_params(nblocks, block_size)
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
return self.memory_out
def set_params(self, nblocks, block_size):
self.nblocks = nblocks
self.block_size = block_size
self.params = b""
self.params += pack(
cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
self.device_handles_cp,
self.memory,
self.scratch,
self.memory_out,
@@ -112,13 +150,24 @@ class MscclppAllReduce2:
ctypes.c_size_t(self.memory.size),
)
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, 21, 512, 0, stream_ptr)
return self.memory_out
def auto_tune(self):
    """Generate every candidate ``(nblocks, block_size)`` configuration.

    Each candidate is applied through ``set_params`` right before it is yielded.
    ``nblocks`` is the outer (slowest-varying) dimension. After the generator is
    exhausted, the last candidate remains applied.
    """
    grid_sizes = (21, 42, 63, 84, 105)
    thread_counts = (256, 512, 1024)
    for grid in grid_sizes:
        for threads in thread_counts:
            config = (grid, threads)
            self.set_params(*config)
            yield config
class MscclppAllReduce3:
def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, proxy_service: ProxyService):
def __init__(
self,
group: mscclpp_comm.CommGroup,
memory: cp.ndarray,
proxy_service: ProxyService,
block_size: int = 1024,
nblocks: int = 24,
):
self.group = group
self.memory = memory
remote_nghrs = list(range(self.group.nranks))
@@ -141,16 +190,28 @@ class MscclppAllReduce3:
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
).get_compiled_kernel()
self.params = b""
self.fst_device_handles = []
self.snd_device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)
self.set_params(nblocks, block_size)
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
return self.memory
def set_params(self, nblocks, block_size):
self.nblocks = nblocks
self.block_size = block_size
self.params = b""
self.params += pack(
cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8),
cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8),
self.fst_device_handles_cp,
self.snd_device_handles_cp,
self.memory,
self.scratch,
self.group.my_rank,
@@ -158,9 +219,13 @@ class MscclppAllReduce3:
ctypes.c_size_t(self.memory.size),
)
def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
return self.memory
def auto_tune(self):
    """Generate every candidate ``(nblocks, block_size)`` configuration.

    The candidate grid is enumerated with ``nblocks`` varying slowest. Each
    pair is applied through ``set_params`` just before being yielded. After the
    generator is exhausted, the last candidate remains applied.
    """
    candidates = [
        (grid, threads)
        for grid in [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
        for threads in [256, 512, 1024]
    ]
    for grid, threads in candidates:
        self.set_params(grid, threads)
        yield grid, threads
class MscclppAllReduce4:
@@ -220,6 +285,14 @@ class MscclppAllReduce4:
)
self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
)
self.all_gather_proxy_device_handles_cp = cp.asarray(
memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8
)
self.set_params(nblocks, block_size, pipeline_depth)
def __call__(self, stream_ptr):
@@ -233,9 +306,9 @@ class MscclppAllReduce4:
self.params = b""
self.params += pack(
cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
cp.asarray(memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8),
cp.asarray(memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8),
self.sm_device_handles_cp,
self.reduce_sactter_proxy_device_handles_cp,
self.all_gather_proxy_device_handles_cp,
self.memory,
self.scratch,
self.group.my_rank,
@@ -310,6 +383,9 @@ class MscclppAllReduce5:
if rank != self.group.my_rank and not in_same_node(rank):
self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)
self.set_params(nblocks, block_size)
def __call__(self, stream_ptr):
@@ -322,8 +398,8 @@ class MscclppAllReduce5:
self.params = b""
self.params += pack(
cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8),
self.sm_device_handles_cp,
self.proxy_device_handles_cp,
self.memory,
self.scratch,
self.put_buff,

View File

@@ -5,3 +5,4 @@ cuda-python
netifaces
pytest
numpy
matplotlib

View File

@@ -5,3 +5,4 @@ cuda-python
netifaces
pytest
numpy
matplotlib