mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-08 15:30:41 +00:00
Auto-tune single-node AllReduce (#219)
single node auto-tuner + graph plotter + bug fix for illegal memory access --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -118,12 +118,9 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
|
||||
// AllReduce1
|
||||
// -------------------------------------------
|
||||
|
||||
#ifndef READ_ONLY
|
||||
#define READ_ONLY 0
|
||||
#endif
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks, size_t nelems) {
|
||||
template <int READ_ONLY>
|
||||
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
|
||||
size_t nelems) {
|
||||
const size_t chunkSize = nelems / nranks;
|
||||
if (nranks == 1) return;
|
||||
const int nPeer = nranks - 1;
|
||||
@@ -211,13 +208,21 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
|
||||
int rank, int nranks, size_t nelems, int read_only) {
|
||||
if (read_only)
|
||||
allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
|
||||
else
|
||||
allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
|
||||
}
|
||||
|
||||
// -------------------------------------------
|
||||
// AllReduce2
|
||||
// -------------------------------------------
|
||||
|
||||
__device__ uint64_t globalFlag = 1;
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(512, 1)
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
|
||||
int worldSize, size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
|
||||
@@ -23,6 +23,46 @@ else:
|
||||
raise RuntimeError("Unknown data type")
|
||||
|
||||
|
||||
def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
human_readable_sizes = [human_readable_size(size) for size in sizes]
|
||||
|
||||
fig, ax1 = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# Plotting AlgBW for MSCCLPP and NCCL on the primary y-axis
|
||||
(line1,) = ax1.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
|
||||
(line2,) = ax1.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
|
||||
ax1.set_ylabel("AlgBW (GB/s)")
|
||||
ax1.set_xlabel("Data Size")
|
||||
|
||||
# Logarithmic x-axis
|
||||
ax1.set_xscale("log", base=2)
|
||||
ax1.set_xticks(sizes)
|
||||
ax1.set_xticklabels(human_readable_sizes, rotation=45)
|
||||
|
||||
# Adding secondary y-axis for Speed Up
|
||||
ax2 = ax1.twinx()
|
||||
(line3,) = ax2.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
|
||||
ax2.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
|
||||
ax2.tick_params(axis="y", labelcolor="green")
|
||||
|
||||
# Set the lower bound of the secondary y-axis to 0
|
||||
ax2.set_ylim(bottom=0)
|
||||
|
||||
# Creating legends
|
||||
lines = [line1, line2, line3]
|
||||
labels = [line.get_label() for line in lines]
|
||||
ax1.legend(lines, labels, loc="upper left")
|
||||
|
||||
# Setting title and grid
|
||||
ax1.set_title("MSCCLPP vs NCCL -- " + str(MPI.COMM_WORLD.size // N_GPUS_PER_NODE) + " Nodes")
|
||||
ax2.grid(True, which="both", ls="--")
|
||||
|
||||
# Saving the plot
|
||||
plt.savefig("mscclpp_vs_nccl_comparison.pdf", format="pdf")
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=1):
|
||||
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
|
||||
if size < 1024.0 or unit == "PiB":
|
||||
@@ -99,15 +139,12 @@ def run_benchmark(
|
||||
memory_out = cp.zeros(nelem, dtype=data_type)
|
||||
cp.cuda.runtime.deviceSynchronize()
|
||||
|
||||
proxy_service = None
|
||||
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
|
||||
if memory.nbytes < 2**20:
|
||||
mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
|
||||
elif memory.nbytes < 2**29:
|
||||
if memory.nbytes >= 2**20 and memory.nbytes <= 2**22:
|
||||
read_only = 0
|
||||
else:
|
||||
read_only = 1
|
||||
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory, read_only=read_only)
|
||||
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
|
||||
else:
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
|
||||
@@ -117,14 +154,13 @@ def run_benchmark(
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
|
||||
proxy_service.start_proxy()
|
||||
best_config = find_best_config(mscclpp_call, 100)
|
||||
mscclpp_call.set_params(*best_config)
|
||||
else:
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
|
||||
proxy_service.start_proxy()
|
||||
best_config = find_best_config(mscclpp_call, 20)
|
||||
mscclpp_call.set_params(*best_config)
|
||||
|
||||
best_config = find_best_config(mscclpp_call, 20)
|
||||
mscclpp_call.set_params(*best_config)
|
||||
|
||||
nccl_call = NcclAllReduce(nccl_op, memory)
|
||||
|
||||
@@ -145,6 +181,7 @@ def run_benchmark(
|
||||
MPI.COMM_WORLD.barrier()
|
||||
proxy_service.stop_proxy()
|
||||
|
||||
speed_up = nccl_time / mscclpp_time
|
||||
if MPI.COMM_WORLD.rank == 0:
|
||||
table.add_row(
|
||||
[
|
||||
@@ -155,12 +192,14 @@ def run_benchmark(
|
||||
"{:.2f}".format(nccl_time),
|
||||
"{:.2f}".format(nccl_algBw),
|
||||
nccl_check,
|
||||
"{:.2f}".format(nccl_time / mscclpp_time),
|
||||
"{:.2f}".format(speed_up),
|
||||
]
|
||||
)
|
||||
if MPI.COMM_WORLD.rank == 0:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
|
||||
@@ -200,16 +239,29 @@ if __name__ == "__main__":
|
||||
"Speed Up",
|
||||
]
|
||||
|
||||
for i in range(10, 28):
|
||||
sizes = []
|
||||
mscclpp_algbw = []
|
||||
nccl_algbw = []
|
||||
speed_ups = []
|
||||
for i in range(10, 30):
|
||||
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
|
||||
run_benchmark(mscclpp_group, nccl_comm, table, 100, 2**i)
|
||||
nelems = 2**i
|
||||
elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
|
||||
run_benchmark(mscclpp_group, nccl_comm, table, 100, 3 * 2**i)
|
||||
nelems = 3 * 2**i
|
||||
else:
|
||||
raise RuntimeError("Only support one node/two nodes communication")
|
||||
|
||||
size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
|
||||
sizes.append(size)
|
||||
mscclpp_algbw.append(mscclpp_algBw)
|
||||
nccl_algbw.append(nccl_algBw)
|
||||
speed_ups.append(speed_up)
|
||||
|
||||
if MPI.COMM_WORLD.rank == 0:
|
||||
print()
|
||||
print(table)
|
||||
|
||||
plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)
|
||||
|
||||
mscclpp_group = None
|
||||
nccl_comm = None
|
||||
|
||||
@@ -35,7 +35,7 @@ class MscclppAllReduce1:
|
||||
group: mscclpp_comm.CommGroup,
|
||||
memory: cp.ndarray,
|
||||
read_only: int = 1,
|
||||
nthreads: int = 1024,
|
||||
block_size: int = 1024,
|
||||
nblocks: int = 24,
|
||||
):
|
||||
self.group = group
|
||||
@@ -55,30 +55,55 @@ class MscclppAllReduce1:
|
||||
file="allreduce.cu",
|
||||
kernel_name="allreduce1",
|
||||
file_dir=file_dir,
|
||||
macro_dict={"TYPE": type_str, "READ_ONLY": str(read_only)},
|
||||
macro_dict={"TYPE": type_str},
|
||||
).get_compiled_kernel()
|
||||
self.params = b""
|
||||
self.device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
|
||||
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
|
||||
|
||||
self.set_params(nblocks, block_size, read_only)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
|
||||
return self.memory
|
||||
|
||||
def set_params(self, nblocks, block_size, read_only):
|
||||
self.nblocks = nblocks
|
||||
self.block_size = block_size
|
||||
self.read_only = read_only
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
|
||||
self.device_handles_cp,
|
||||
self.memory,
|
||||
self.group.my_rank,
|
||||
self.group.nranks,
|
||||
ctypes.c_size_t(self.memory.size),
|
||||
self.read_only,
|
||||
)
|
||||
self.nthreads = nthreads
|
||||
self.nblocks = nblocks
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, stream_ptr)
|
||||
return self.memory
|
||||
def auto_tune(self):
|
||||
nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
|
||||
block_size_to_try = [256, 512, 1024]
|
||||
read_only_to_try = [0, 1]
|
||||
for nblocks in nblocks_to_try:
|
||||
for block_size in block_size_to_try:
|
||||
for read_only in read_only_to_try:
|
||||
self.set_params(nblocks, block_size, read_only)
|
||||
yield nblocks, block_size, read_only
|
||||
|
||||
|
||||
class MscclppAllReduce2:
|
||||
def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, memory_out: cp.ndarray):
|
||||
def __init__(
|
||||
self,
|
||||
group: mscclpp_comm.CommGroup,
|
||||
memory: cp.ndarray,
|
||||
memory_out: cp.ndarray,
|
||||
block_size: int = 512,
|
||||
nblocks: int = 21,
|
||||
):
|
||||
self.group = group
|
||||
self.memory = memory
|
||||
self.memory_out = memory_out
|
||||
@@ -97,13 +122,26 @@ class MscclppAllReduce2:
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
).get_compiled_kernel()
|
||||
self.params = b""
|
||||
self.device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
|
||||
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
|
||||
|
||||
self.set_params(nblocks, block_size)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
|
||||
return self.memory_out
|
||||
|
||||
def set_params(self, nblocks, block_size):
|
||||
self.nblocks = nblocks
|
||||
self.block_size = block_size
|
||||
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8),
|
||||
self.device_handles_cp,
|
||||
self.memory,
|
||||
self.scratch,
|
||||
self.memory_out,
|
||||
@@ -112,13 +150,24 @@ class MscclppAllReduce2:
|
||||
ctypes.c_size_t(self.memory.size),
|
||||
)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, 21, 512, 0, stream_ptr)
|
||||
return self.memory_out
|
||||
def auto_tune(self):
|
||||
nblocks_to_try = [21, 42, 63, 84, 105]
|
||||
block_size_to_try = [256, 512, 1024]
|
||||
for nblocks in nblocks_to_try:
|
||||
for block_size in block_size_to_try:
|
||||
self.set_params(nblocks, block_size)
|
||||
yield nblocks, block_size
|
||||
|
||||
|
||||
class MscclppAllReduce3:
|
||||
def __init__(self, group: mscclpp_comm.CommGroup, memory: cp.ndarray, proxy_service: ProxyService):
|
||||
def __init__(
|
||||
self,
|
||||
group: mscclpp_comm.CommGroup,
|
||||
memory: cp.ndarray,
|
||||
proxy_service: ProxyService,
|
||||
block_size: int = 1024,
|
||||
nblocks: int = 24,
|
||||
):
|
||||
self.group = group
|
||||
self.memory = memory
|
||||
remote_nghrs = list(range(self.group.nranks))
|
||||
@@ -141,16 +190,28 @@ class MscclppAllReduce3:
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
).get_compiled_kernel()
|
||||
self.params = b""
|
||||
self.fst_device_handles = []
|
||||
self.snd_device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
|
||||
self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
|
||||
self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
|
||||
self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)
|
||||
|
||||
self.set_params(nblocks, block_size)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
|
||||
return self.memory
|
||||
|
||||
def set_params(self, nblocks, block_size):
|
||||
self.nblocks = nblocks
|
||||
self.block_size = block_size
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8),
|
||||
cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8),
|
||||
self.fst_device_handles_cp,
|
||||
self.snd_device_handles_cp,
|
||||
self.memory,
|
||||
self.scratch,
|
||||
self.group.my_rank,
|
||||
@@ -158,9 +219,13 @@ class MscclppAllReduce3:
|
||||
ctypes.c_size_t(self.memory.size),
|
||||
)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
|
||||
return self.memory
|
||||
def auto_tune(self):
|
||||
nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
|
||||
block_size_to_try = [256, 512, 1024]
|
||||
for nblocks in nblocks_to_try:
|
||||
for block_size in block_size_to_try:
|
||||
self.set_params(nblocks, block_size)
|
||||
yield nblocks, block_size
|
||||
|
||||
|
||||
class MscclppAllReduce4:
|
||||
@@ -220,6 +285,14 @@ class MscclppAllReduce4:
|
||||
)
|
||||
self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)
|
||||
|
||||
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
|
||||
self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
|
||||
memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
|
||||
)
|
||||
self.all_gather_proxy_device_handles_cp = cp.asarray(
|
||||
memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8
|
||||
)
|
||||
|
||||
self.set_params(nblocks, block_size, pipeline_depth)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
@@ -233,9 +306,9 @@ class MscclppAllReduce4:
|
||||
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
|
||||
cp.asarray(memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8),
|
||||
cp.asarray(memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8),
|
||||
self.sm_device_handles_cp,
|
||||
self.reduce_sactter_proxy_device_handles_cp,
|
||||
self.all_gather_proxy_device_handles_cp,
|
||||
self.memory,
|
||||
self.scratch,
|
||||
self.group.my_rank,
|
||||
@@ -310,6 +383,9 @@ class MscclppAllReduce5:
|
||||
if rank != self.group.my_rank and not in_same_node(rank):
|
||||
self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)
|
||||
|
||||
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
|
||||
self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)
|
||||
|
||||
self.set_params(nblocks, block_size)
|
||||
|
||||
def __call__(self, stream_ptr):
|
||||
@@ -322,8 +398,8 @@ class MscclppAllReduce5:
|
||||
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8),
|
||||
cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8),
|
||||
self.sm_device_handles_cp,
|
||||
self.proxy_device_handles_cp,
|
||||
self.memory,
|
||||
self.scratch,
|
||||
self.put_buff,
|
||||
|
||||
@@ -5,3 +5,4 @@ cuda-python
|
||||
netifaces
|
||||
pytest
|
||||
numpy
|
||||
matplotlib
|
||||
|
||||
@@ -5,3 +5,4 @@ cuda-python
|
||||
netifaces
|
||||
pytest
|
||||
numpy
|
||||
matplotlib
|
||||
|
||||
Reference in New Issue
Block a user