Renaming channels (#436)

Renamed `ProxyChannel` to `PortChannel` and `SmChannel` to
`MemoryChannel`
This commit is contained in:
Changho Hwang
2025-01-24 14:25:31 -08:00
committed by GitHub
parent af0bb86e07
commit 3565bfdf6d
63 changed files with 1372 additions and 1272 deletions

View File

@@ -28,7 +28,7 @@ def allgather_test(gpus, instances):
c = chunk(n, Buffer.input, 0, 1)
for peer in range(gpus):
if n != peer:
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.memory)
else:
c.copy(n, Buffer.output, n, sendtb=peer)
# explicit barrier
@@ -36,13 +36,13 @@ def allgather_test(gpus, instances):
r.barrier(tb_list=list(range(gpus)))
for peer in range(gpus):
if n != peer:
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.memory)
for n in range(gpus):
for peer in range(gpus):
c = chunk(n, Buffer.output, peer, 1)
if n != peer:
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm)
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.memory)
Json()
Check()

View File

@@ -10,9 +10,9 @@ from mscclpp.language.types import ChannelType
def send_recv(instances):
"""
Send and receive data between two ranks using proxy channels, with LL protocol and double scratch buffer.
Send and receive data between two ranks using port channels, with LL protocol and double scratch buffer.
Steps:
1. Each rank sends a chunk to every other rank's scratch buffer with packet format via proxy channel.
1. Each rank sends a chunk to every other rank's scratch buffer with packet format via port channel.
2. Wait for the data to be received, then copy it to the output buffer.
"""
size = 2
@@ -36,7 +36,7 @@ def send_recv(instances):
"scratch",
1,
sendtb=0,
chan_type=ChannelType.proxy,
chan_type=ChannelType.port,
temp_buffer="scratch",
temp_buffer_index=0,
)

View File

@@ -10,7 +10,7 @@ from mscclpp.language.types import ChannelType
def send_recv(instances):
"""
Send and receive data between two ranks using proxy channels.
Send and receive data between two ranks using port channels.
steps:
1. Each rank sends a chunk to the other rank's scratch buffer and signals the other rank that the data has been sent.
2. Wait for the data to be received then copy it to the output buffer.
@@ -34,14 +34,14 @@ def send_recv(instances):
"scratch",
1,
sendtb=0,
chan_type=ChannelType.proxy,
chan_type=ChannelType.port,
)
c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
for r in range(size):
c = chunk(r, "scratch", 1)
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy)
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.port)
c.copy(r, Buffer.output, 0, sendtb=0)
Json()

View File

@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os as _os
import os
import warnings
from functools import wraps
from ._mscclpp import (
Env,
@@ -22,9 +24,9 @@ from ._mscclpp import (
numa,
ProxyService,
RegisteredMemory,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
PortChannel,
MemoryChannel,
MemoryDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
TransportFlags,
@@ -39,17 +41,82 @@ from ._mscclpp import (
npkit,
)
__version__ = version()
if _os.environ.get("MSCCLPP_HOME", None) is None:
_os.environ["MSCCLPP_HOME"] = _os.path.abspath(_os.path.dirname(__file__))
__all__ = [
"Communicator",
"Connection",
"connect_nvls_collective",
"EndpointConfig",
"Fifo",
"Host2DeviceSemaphore",
"Host2HostSemaphore",
"numa",
"ProxyService",
"RegisteredMemory",
"PortChannel",
"MemoryChannel",
"MemoryDevice2DeviceSemaphore",
"TcpBootstrap",
"Transport",
"TransportFlags",
"DataType",
"Executor",
"ExecutionPlan",
"PacketType",
"version",
"is_nvls_supported",
"alloc_shared_physical_cuda",
"npkit",
"__version__",
"get_include",
"get_lib",
### Deprecated ###
"ProxyChannel",
"SmChannel",
"SmDevice2DeviceSemaphore",
]
__version__: str = str(version())
if os.environ.get("MSCCLPP_HOME", None) is None:
os.environ["MSCCLPP_HOME"] = os.path.abspath(os.path.dirname(__file__))
def get_include():
def get_include() -> str:
"""Return the directory that contains the MSCCL++ headers."""
return _os.path.join(_os.path.dirname(__file__), "include")
return os.path.join(os.path.dirname(__file__), "include")
def get_lib():
def get_lib() -> str:
"""Return the directory that contains the MSCCL++ libraries."""
return _os.path.join(_os.path.dirname(__file__), "lib")
return os.path.join(os.path.dirname(__file__), "lib")
def deprecated(new_cls):
    """Build a decorator that redirects a deprecated class name to ``new_cls``.

    The decorated name ends up bound to a plain function rather than a class:
    calling it emits a ``DeprecationWarning`` that names both classes and then
    returns ``new_cls(*args, **kwargs)``. The decorated class itself is never
    instantiated; it only supplies the metadata copied by ``functools.wraps``.

    Args:
        new_cls: The replacement class that callers should migrate to.

    Returns:
        A decorator that takes the deprecated class and returns the warning
        wrapper described above.
    """

    def decorator(old_cls):
        @wraps(old_cls)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{old_cls.__name__} is deprecated, use {new_cls.__name__} instead.",
                DeprecationWarning,
                # Attribute the warning to the code that used the deprecated
                # name instead of to this wrapper, so the reported location is
                # actionable and caller-based warning filters apply.
                stacklevel=2,
            )
            return new_cls(*args, **kwargs)

        return wrapper

    return decorator
# Deprecated name for PortChannel. Because of @deprecated, the name is bound to the
# wrapper function: calling it warns and returns PortChannel(*args, **kwargs).
@deprecated(PortChannel)
class ProxyChannel(PortChannel):
    pass
# Deprecated name for MemoryChannel. Because of @deprecated, the name is bound to the
# wrapper function: calling it warns and returns MemoryChannel(*args, **kwargs).
@deprecated(MemoryChannel)
class SmChannel(MemoryChannel):
    pass
# Deprecated name for MemoryDevice2DeviceSemaphore. Because of @deprecated, the name is bound
# to the wrapper function: calling it warns and returns MemoryDevice2DeviceSemaphore(*args, **kwargs).
@deprecated(MemoryDevice2DeviceSemaphore)
class SmDevice2DeviceSemaphore(MemoryDevice2DeviceSemaphore):
    pass

View File

@@ -14,9 +14,9 @@ from ._mscclpp import (
Host2HostSemaphore,
ProxyService,
RegisteredMemory,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
PortChannel,
MemoryChannel,
MemoryDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
TransportFlags,
@@ -135,7 +135,7 @@ class CommGroup:
def make_semaphore(
self,
connections: dict[int, Connection],
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[SmDevice2DeviceSemaphore],
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[MemoryDevice2DeviceSemaphore],
) -> dict[int, Host2HostSemaphore]:
semaphores = {}
for rank in connections:
@@ -143,33 +143,35 @@ class CommGroup:
self.communicator.setup()
return semaphores
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
channels[rank] = MemoryChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
return channels
def make_sm_channels_with_scratch(
def make_memory_channels_with_scratch(
self,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
connections: dict[int, Connection],
) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
channels[rank] = MemoryChannel(
semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr
)
return channels
def make_proxy_channels(
def make_port_channels(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, SmChannel]:
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
@@ -180,18 +182,16 @@ class CommGroup:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
channels[rank] = proxy_service.port_channel(semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank])
return channels
def make_proxy_channels_with_scratch(
def make_port_channels_with_scratch(
self,
proxy_service: ProxyService,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
connections: dict[int, Connection],
) -> dict[int, SmChannel]:
) -> dict[int, MemoryChannel]:
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
@@ -218,21 +218,19 @@ class CommGroup:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
channels[rank] = proxy_service.port_channel(semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank])
return channels
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, SmChannel]:
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphore_ids = {}
for rank in semaphores:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.base_proxy_channel(semaphore_ids[rank])
channels[rank] = proxy_service.base_port_channel(semaphore_ids[rank])
return channels
def register_memory_with_proxy(

View File

@@ -15,8 +15,8 @@ using namespace mscclpp;
extern void register_env(nb::module_& m);
extern void register_error(nb::module_& m);
extern void register_proxy_channel(nb::module_& m);
extern void register_sm_channel(nb::module_& m);
extern void register_port_channel(nb::module_& m);
extern void register_memory_channel(nb::module_& m);
extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
@@ -187,8 +187,8 @@ void register_core(nb::module_& m) {
NB_MODULE(_mscclpp, m) {
register_env(m);
register_error(m);
register_proxy_channel(m);
register_sm_channel(m);
register_port_channel(m);
register_memory_channel(m);
register_fifo(m);
register_semaphore(m);
register_utils(m);

View File

@@ -6,7 +6,6 @@ from mscclpp.language.chunk import Chunk, ReduceChunk
class Collective:
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
self.num_ranks = num_ranks
self.chunk_factor = chunk_factor
@@ -36,7 +35,6 @@ class Collective:
class AllToAll(Collective):
def __init__(self, num_ranks, chunk_factor, inplace):
Collective.__init__(self, num_ranks, chunk_factor, inplace)
self.name = "alltoall"
@@ -137,7 +135,6 @@ class AllGather(Collective):
class AllReduce(Collective):
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
num_chunk_groups = kwargs.get("num_chunk_groups", num_ranks)
Collective.__init__(

View File

@@ -221,7 +221,7 @@ class InstructionDAG:
next=set(),
prev=set(),
tb=tb,
channel_type=ChannelType.proxy,
channel_type=ChannelType.port,
step=tb_step,
)
buffer = send_ref.buffer

View File

@@ -19,7 +19,6 @@ from mscclpp.language.types import ChunkRef, ChannelType, Instruction, Op, Threa
class _InstructionOptimizer:
def try_merge_same_instructions(
self,
op: Op,
@@ -128,8 +127,8 @@ class _InstructionOptimizer:
and same_tb(op, next_op)
and same_count(op, next_op)
and buf_dst_src_match(op, next_op)
and next_op.channel_type == ChannelType.sm
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
and next_op.channel_type == ChannelType.memory
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.memory)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
@@ -140,10 +139,10 @@ class _InstructionOptimizer:
op.inst = Instruction.read_reduce_copy_send
elif op.inst == Instruction.reduce:
op.inst = Instruction.reduce_send
op.channel_type = ChannelType.sm
op.channel_type = ChannelType.memory
elif op.inst == Instruction.reduce_packet:
op.inst = Instruction.reduce_send_packet
op.channel_type = ChannelType.sm
op.channel_type = ChannelType.memory
# Append the destination chunk from next_op
op.dsts.append(
(
@@ -158,11 +157,11 @@ class _InstructionOptimizer:
return True
return False
def try_fuse_instructions_using_proxy_channel(
def try_fuse_instructions_using_port_channel(
self, op: Op, next_op: Op, tb: Threadblock, queue: list, expected_next_inst: Instruction
) -> bool:
"""
Attempts to fuse operations which using proxy channel.
Attempts to fuse operations that use a port channel.
:param op: The current operation.
:param next_op: The next operation to potentially merge with.
:param tb: The thread block containing the operations.
@@ -177,7 +176,7 @@ class _InstructionOptimizer:
and same_buf_dst(op, next_op)
and same_buf_src(op, next_op)
and same_chan_type(op, next_op)
and op.channel_type == ChannelType.proxy
and op.channel_type == ChannelType.port
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
@@ -229,7 +228,6 @@ class _InstructionOptimizer:
class DagOptimizer:
def __init__(self, instruction_dag: InstructionDAG):
self.optimizer = _InstructionOptimizer()
self.dag = instruction_dag
@@ -257,7 +255,7 @@ class DagOptimizer:
queue = queue[1:]
def fuse_instructions(self):
self._fuse_instructions_using_proxy_channel()
self._fuse_instructions_using_port_channel()
self._fuse_same_instructions()
self._optimize_rrcs_rs()
self._optimize_group_ops()
@@ -267,7 +265,7 @@ class DagOptimizer:
# -> putWithSignal(src, sbuf, si, dst, dbuf, di)
# put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) flush(src, sbuf, si, dst, dbuf, di)
# -> putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
def _fuse_instructions_using_proxy_channel(self):
def _fuse_instructions_using_port_channel(self):
inst_followup_map = {
Instruction.put: Instruction.signal,
Instruction.put_with_signal: Instruction.flush,
@@ -280,7 +278,7 @@ class DagOptimizer:
fused = False
if op.inst in inst_followup_map:
for next_op in op.next:
fused = self.optimizer.try_fuse_instructions_using_proxy_channel(
fused = self.optimizer.try_fuse_instructions_using_port_channel(
op, next_op, tb, queue, inst_followup_map[op.inst]
)
if fused:

View File

@@ -286,7 +286,7 @@ class _ReadReduceCopySendConverter(_OpConverter):
class _ReduceSendConverter(_OpConverter):
def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
dst_channel_ids = self.get_channel_ids(
op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm
op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.memory
)
o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))

View File

@@ -222,7 +222,7 @@ class Ref(ChunkRef):
return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
return buffer, index
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory, use_packet=False):
self.prog.check_buffer_exists(dst, buffer)
assert self.rank != dst, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)
@@ -237,7 +237,7 @@ class Ref(ChunkRef):
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
return dst_chunkref
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
return self._put(dst, buffer, index, sendtb, chan_type)
def put_packet(
@@ -246,19 +246,19 @@ class Ref(ChunkRef):
buffer=None,
index=-1,
sendtb=-1,
chan_type=ChannelType.sm,
chan_type=ChannelType.memory,
temp_buffer=None,
temp_buffer_index=-1,
):
chunk_ref = self
if chan_type == ChannelType.proxy:
assert temp_buffer is not None, "Need to specify a temporary buffer for proxy channels"
if chan_type == ChannelType.port:
assert temp_buffer is not None, "Need to specify a temporary buffer for port channels"
chunk_ref = self._copy(
self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
)
return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
self.prog.check_buffer_exists(src, buffer)
sender = src
receiver = self.rank
@@ -273,7 +273,7 @@ class Ref(ChunkRef):
# For signal and wait, we currently assume the pair uses the same tb index. In the future we need
# to infer the tb index from the instruction DAG. A channel is defined as (send_tb, src_buffer, recv_tb, dst_buffer, type).
# Then we can use DAG info to reduce the number of channels.
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot signal to the same rank"
@@ -282,9 +282,9 @@ class Ref(ChunkRef):
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)
# only proxy channel need to use this function
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy):
assert chan_type == ChannelType.proxy, "Only proxy channel can use flush"
# Only port channels need to use this function.
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port):
assert chan_type == ChannelType.port, "Only port channel can use flush"
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot flush to the same rank"
@@ -293,7 +293,7 @@ class Ref(ChunkRef):
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
sender = src
receiver = self.rank
assert sender != receiver, "Cannot wait on the same rank"
@@ -324,7 +324,7 @@ class Ref(ChunkRef):
def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False):
dst = self.rank
src = other_chunkref.rank
@@ -342,7 +342,7 @@ class Ref(ChunkRef):
return self
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm):
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory):
return self._reduce(other_chunkref, recvtb, channel_type)
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref

View File

@@ -114,11 +114,15 @@ class ChunkRef:
class ChannelType(Enum):
proxy = "proxy"
sm = "sm"
port = "port"
memory = "memory"
none = "none"
nvls = "nvls"
# Deprecated
proxy = "port"
sm = "memory"
def __str__(self):
return self.value

View File

@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/memory_channel.hpp>
namespace nb = nanobind;
using namespace mscclpp;
// Registers the Python bindings for MemoryChannel and MemoryChannel::DeviceHandle on module `m`.
void register_memory_channel(nb::module_& m) {
nb::class_<MemoryChannel> memoryChannel(m, "MemoryChannel");
memoryChannel
// Constructor overload (semaphore, dst, src): `src` is received as an integer address and cast to void*.
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, uintptr_t src) { new (memoryChannel) MemoryChannel(semaphore, dst, (void*)src); })
// Constructor overload that additionally takes `get_packet_buffer`, also an integer address cast to void*.
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, uintptr_t src, uintptr_t get_packet_buffer) {
new (memoryChannel) MemoryChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
})
.def("device_handle", &MemoryChannel::deviceHandle);
// The device handle is exposed to Python as "MemoryChannelDeviceHandle" with read/write fields.
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
.def_rw("getPacketBuffer_", &MemoryChannel::DeviceHandle::getPacketBuffer_)
// Read-only `raw` property: the handle's in-memory representation, i.e. sizeof(self) bytes starting at &self.
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -5,12 +5,12 @@
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/proxy_channel.hpp>
#include <mscclpp/port_channel.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_proxy_channel(nb::module_& m) {
void register_port_channel(nb::module_& m) {
nb::class_<BaseProxyService>(m, "BaseProxyService")
.def("start_proxy", &BaseProxyService::startProxy)
.def("stop_proxy", &BaseProxyService::stopProxy);
@@ -23,36 +23,36 @@ void register_proxy_channel(nb::module_& m) {
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("base_proxy_channel", &ProxyService::baseProxyChannel, nb::arg("id"))
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
nb::class_<BaseProxyChannel>(m, "BaseProxyChannel")
nb::class_<BasePortChannel>(m, "BasePortChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &BaseProxyChannel::deviceHandle);
.def("device_handle", &BasePortChannel::deviceHandle);
nb::class_<BaseProxyChannel::DeviceHandle>(m, "BaseProxyChannelDeviceHandle")
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &BaseProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &BaseProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &BaseProxyChannel::DeviceHandle::fifo_)
.def_prop_ro("raw", [](const BaseProxyChannel::DeviceHandle& self) -> nb::bytes {
.def_rw("semaphoreId_", &BasePortChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &BasePortChannel::DeviceHandle::fifo_)
.def_prop_ro("raw", [](const BasePortChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<ProxyChannel>(m, "ProxyChannel")
nb::class_<PortChannel>(m, "PortChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
.def("device_handle", &ProxyChannel::deviceHandle);
.def("device_handle", &PortChannel::deviceHandle);
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
.def_rw("src_", &ProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &ProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
.def_rw("semaphoreId_", &PortChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &PortChannel::DeviceHandle::fifo_)
.def_rw("src_", &PortChannel::DeviceHandle::src_)
.def_rw("dst_", &PortChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const PortChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -33,18 +33,18 @@ void register_semaphore(nb::module_& m) {
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
nb::arg("max_spin_count") = 10000000);
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
smDevice2DeviceSemaphore
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
memoryDevice2DeviceSemaphore
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
.def_rw("inboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("outboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
.def_rw("remoteInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
.def_rw("expectedInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
}

View File

@@ -1,35 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/sm_channel.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_sm_channel(nb::module_& m) {
nb::class_<SmChannel> smChannel(m, "SmChannel");
smChannel
.def("__init__",
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
.def("__init__",
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
uintptr_t src, uintptr_t get_packet_buffer) {
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
})
.def("device_handle", &SmChannel::deviceHandle);
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
.def_rw("src_", &SmChannel::DeviceHandle::src_)
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -8,9 +8,9 @@
#endif
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/proxy_channel_device.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/port_channel_device.hpp>
__device__ mscclpp::DeviceSyncer deviceSyncer;
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
@@ -124,7 +124,7 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
// -------------------------------------------
template <int READ_ONLY>
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
__device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank, int nranks,
size_t nelems) {
const size_t chunkSize = nelems / nranks;
if (nranks == 1) return;
@@ -140,10 +140,10 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
}
__syncthreads();
if (tid < nPeer) {
smChans[tid].relaxedSignal();
memChans[tid].relaxedSignal();
}
if (tid >= nPeer && tid < nPeer * 2) {
smChans[tid - nPeer].wait();
memChans[tid - nPeer].wait();
}
deviceSyncer.sync(gridDim.x);
@@ -155,14 +155,14 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
int4 val;
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
val = memChans[peerIdx].read<int4>(indexOffset4 + idx);
tmp = add_vectors<TYPE>(tmp, val);
}
if (READ_ONLY == 0) {
for (int index = 0; index < nPeer; ++index) {
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
smChans[peerIdx].write<int4>(indexOffset4 + idx, tmp);
memChans[peerIdx].write<int4>(indexOffset4 + idx, tmp);
}
}
buff4[indexOffset4 + idx] = tmp;
@@ -178,14 +178,14 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
for (int index = 0; index < nPeer; ++index) {
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
TYPE val = smChans[peerIdx].read<TYPE>(idx);
TYPE val = memChans[peerIdx].read<TYPE>(idx);
tmp += val;
}
if (READ_ONLY == 0) {
for (int index = 0; index < nPeer; ++index) {
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
smChans[peerIdx].write<TYPE>(idx, tmp);
memChans[peerIdx].write<TYPE>(idx, tmp);
}
}
buff[idx] = tmp;
@@ -198,10 +198,10 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
}
__syncthreads();
if (tid < nPeer) {
smChans[tid].relaxedSignal();
memChans[tid].relaxedSignal();
}
if (tid >= nPeer && tid < nPeer * 2) {
smChans[tid - nPeer].wait();
memChans[tid - nPeer].wait();
}
if (READ_ONLY) {
@@ -211,17 +211,18 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
if (peerIdx >= nPeer) peerIdx -= nPeer;
const int remoteRank = (peerIdx < rank ? peerIdx : peerIdx + 1);
size_t offset = chunkSize * remoteRank * sizeof(TYPE);
smChans[peerIdx].get(offset, chunkSize * sizeof(TYPE), tid, blockDim.x * gridDim.x);
memChans[peerIdx].get(offset, chunkSize * sizeof(TYPE), tid, blockDim.x * gridDim.x);
}
}
}
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
int rank, int nranks, size_t nelems, int read_only) {
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce1(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank, int nranks, size_t nelems,
int read_only) {
if (read_only)
allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
allreduce1_helper<1>(memChans, buff, rank, nranks, nelems);
else
allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
allreduce1_helper<0>(memChans, buff, rank, nranks, nelems);
}
// -------------------------------------------
@@ -231,7 +232,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChan
__device__ uint64_t globalFlag = 1;
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
allreduce2(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
int worldSize, size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
@@ -246,7 +247,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
@@ -259,7 +260,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
// step 1: write to scratch buffer
smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
@@ -279,7 +280,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
smChans[index].write(offset, packet);
memChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
@@ -301,7 +302,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// -------------------------------------------
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce3(mscclpp::ProxyChannelDeviceHandle* fstRoundChans, mscclpp::ProxyChannelDeviceHandle* sndRoundChans,
allreduce3(mscclpp::PortChannelDeviceHandle* fstRoundChans, mscclpp::PortChannelDeviceHandle* sndRoundChans,
TYPE* buff, TYPE* scratch, int rank, int worldSize, size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
@@ -311,10 +312,10 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
int peerSendId = (remoteSendRank < rank) ? remoteSendRank : remoteSendRank - 1;
int peerRecvId = (remoteRecvRank < rank) ? remoteRecvRank : remoteRecvRank - 1;
mscclpp::ProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
mscclpp::ProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
mscclpp::ProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
mscclpp::ProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
mscclpp::PortChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
mscclpp::PortChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
mscclpp::PortChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
mscclpp::PortChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
// Step 1
size_t chunkIndex = (rank + worldSize - 1) % worldSize;
@@ -419,9 +420,9 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// AllReduce4
// 2-node
// -------------------------------------------
__device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nRanksPerNode,
int startChunkIndex, size_t offsetInChunk, size_t chunkSize, size_t nelems,
int nBlocks) {
__device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank,
int nRanksPerNode, int startChunkIndex, size_t offsetInChunk, size_t chunkSize,
size_t nelems, int nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;
const int nPeer = nRanksPerNode - 1;
@@ -433,10 +434,10 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
int4* buff4 = (int4*)buff;
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
smChans[peerIdx].relaxedSignal();
memChans[peerIdx].relaxedSignal();
}
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
smChans[peerIdx].wait();
memChans[peerIdx].wait();
}
reduceScatterDeviceSyncer.sync(nBlocks);
@@ -447,7 +448,7 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
int4 val;
int peerIdx = index + localRankIndexInNode;
if (peerIdx >= nPeer) peerIdx -= nPeer;
val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
val = memChans[peerIdx].read<int4>(indexOffset4 + idx);
tmp = add_vectors<TYPE>(tmp, val);
}
buff4[indexOffset4 + idx] = tmp;
@@ -457,9 +458,9 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
}
// This kernel is the most performant when the number of blocks is a multiple of (nRanksPerNode - 1).
__device__ void localAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
int startRankChunkIndex, uint64_t offsetInRankChunk, uint64_t rankChunkSize,
uint64_t size, size_t nBlocks) {
__device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, int rank, int nRanksPerNode,
int startRankChunkIndex, uint64_t offsetInRankChunk, uint64_t rankChunkSize,
uint64_t size, size_t nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;
const size_t nPeer = nRanksPerNode - 1;
@@ -495,16 +496,16 @@ __device__ void localAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int ra
sizeForThisBlock += lastChunkSize;
}
if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
smChans[peerIdx].relaxedSignal();
smChans[peerIdx].wait();
memChans[peerIdx].relaxedSignal();
memChans[peerIdx].wait();
}
allGatherDeviceSyncer.sync(nBlocks);
size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
smChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
memChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
}
__device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
uint64_t size, size_t nBlocks) {
__device__ void localAllGatherAllPairsMem(mscclpp::MemoryChannelDeviceHandle* memChans, int rank, int nRanksPerNode,
uint64_t size, size_t nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;
@@ -512,24 +513,24 @@ __device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans
const int nPeer = nRanksPerNode - 1;
if (tid < nPeer) {
smChans[tid].signal();
memChans[tid].signal();
}
int waitStart = nBlocks * blockDim.x - nPeer;
if (tid >= waitStart && tid < nBlocks * blockDim.x) {
smChans[tid - waitStart].wait();
memChans[tid - waitStart].wait();
}
allGatherDeviceSyncer.sync(nBlocks);
for (int i = 0; i < nPeer; ++i) {
int peerIdx = (i + rank) % nPeer;
const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
size_t offset = size * remoteRankLocalIndex;
smChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
memChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
}
}
// This is an allgather4 equivalent
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
__device__ void allGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* portChans,
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
// this allgather is a pipelined and hierarchical one and only works for two nodes
// it is implemented as follows:
// Step 1: each node does a local allgather and concurrently,
@@ -544,14 +545,14 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::Pr
int peerRank = (rank + nRanksPerNode) % worldSize;
int peerNodeId = peerRank / nRanksPerNode;
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
mscclpp::PortChannelDeviceHandle portChan = portChans[peer];
const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;
if (peerNodeId == rank / nRanksPerNode) {
localAllGatherSm(smChans, rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
localAllGatherMem(memChans, rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
return;
}
@@ -562,36 +563,37 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::Pr
// Step 1
if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), step1Bytes);
portChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), step1Bytes);
}
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInLocalNode, 0, rankChunkSize, rankChunkSize,
nBlocksForLocalAllGather);
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInLocalNode, 0, rankChunkSize, rankChunkSize,
nBlocksForLocalAllGather);
if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
proxyChan.wait();
proxyChan.flush();
portChan.wait();
portChan.flush();
}
deviceSyncer.sync(gridDim.x);
// Step 2
if (threadIdx.x == 0 && blockIdx.x == 0) {
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
portChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
}
if (step1Bytes > 0)
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, 0, rankChunkSize, step1Bytes,
nBlocksForLocalAllGather);
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInPeerNode, 0, rankChunkSize, step1Bytes,
nBlocksForLocalAllGather);
if (threadIdx.x == 0 && blockIdx.x == 0) {
proxyChan.wait();
proxyChan.flush();
portChan.wait();
portChan.flush();
}
deviceSyncer.sync(gridDim.x);
// Step 3
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
nBlocksForLocalAllGather);
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
nBlocksForLocalAllGather);
}
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
TYPE* buff, TYPE* scratch, int rank, int nRanksPerNode, int worldSize,
size_t nelems, // must be divisible by 3
int pipelineDepth) {
__device__ void reduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memChans,
mscclpp::PortChannelDeviceHandle* portChans, TYPE* buff, TYPE* scratch, int rank,
int nRanksPerNode, int worldSize,
size_t nelems, // must be divisible by 3
int pipelineDepth) {
// this reduce-scatter algorithm works as follows:
// Step 1: each node does a local reduce-scatter on peer node data chunks with 1/pipeline portion of chunk data. For
// example, 2 nodes and each node has 2 ranks. rank 0 and rank 1 perform reduce-scatter on chunk 2 and chunk 3, with
@@ -612,29 +614,29 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
int isComm = (threadIdx.x == 0) && (blockIdx.x == nBlocksForReduceScatter);
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
mscclpp::PortChannelDeviceHandle portChan = portChans[peer];
if (peerNodeId == rank / nRanksPerNode) {
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
return;
}
// step 1: local reduce
int startChunkIndex = peerNodeId * nRanksPerNode;
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize / pipelineSize,
nBlocksForReduceScatter);
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize / pipelineSize,
nBlocksForReduceScatter);
deviceSyncer.sync(gridDim.x);
// step 2: local reduce and exchange data with neighbor
if (isComm) {
size_t offset = (peerRank * chunkSize) * sizeof(int);
// opposite side
proxyChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
portChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
}
if (pipelineSize > 1)
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
(pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
(pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
if (isComm) {
proxyChan.wait();
portChan.wait();
}
if (blockIdx.x >= nBlocksForReduceScatter) {
ibDeviceSyncer.sync(nBlocksRemain);
@@ -645,7 +647,7 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
vectorSum((TYPE*)dst, (TYPE*)src, chunkSize / pipelineSize, blockIdx.x - nBlocksForReduceScatter, nBlocksRemain);
}
if (isComm) {
proxyChan.flush();
portChan.flush();
}
deviceSyncer.sync(gridDim.x);
@@ -653,12 +655,12 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
startChunkIndex = (rank / nRanksPerNode) * nRanksPerNode;
if (isComm && pipelineSize > 1) {
size_t offset = (peerRank * chunkSize + chunkSize / pipelineSize) * sizeof(int);
proxyChan.putWithSignal(offset, (pipelineSize - 1) * chunkSize / pipelineSize * sizeof(int));
portChan.putWithSignal(offset, (pipelineSize - 1) * chunkSize / pipelineSize * sizeof(int));
}
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize,
nBlocksForReduceScatter);
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize,
nBlocksForReduceScatter);
if (isComm && pipelineSize > 1) {
proxyChan.wait();
portChan.wait();
}
deviceSyncer.sync(gridDim.x);
// reduce to related rank; cannot overlap since localReduceScatter also calculates the sum
@@ -667,24 +669,24 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
int* src = (int*)((char*)scratch + offset);
if (pipelineSize > 1) vectorSum((TYPE*)dst, (TYPE*)src, (pipelineSize - 1) * chunkSize / pipelineSize);
if (isComm) {
proxyChan.flush();
portChan.flush();
}
}
extern "C" __global__ void __launch_bounds__(1024, 1) __global__
allreduce4(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* reduceScatterProxyChans,
mscclpp::ProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
allreduce4(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* reduceScatterPortChans,
mscclpp::PortChannelDeviceHandle* allGatherPortChans, TYPE* buff, TYPE* scratch, int rank,
int nRanksPerNode, int worldSize, size_t nelems, int pipelineDepth) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
reduceScatterSm(smChans, reduceScatterProxyChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
pipelineDepth);
reduceScatterMem(memChans, reduceScatterPortChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
pipelineDepth);
deviceSyncer.sync(gridDim.x);
allGatherSm(smChans, allGatherProxyChans, rank, worldSize, nRanksPerNode, nelems / worldSize, pipelineDepth);
allGatherMem(memChans, allGatherPortChans, rank, worldSize, nRanksPerNode, nelems / worldSize, pipelineDepth);
}
// allreduce 5 for 2-nodes
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans, TYPE* buff,
allreduce5(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* portChans, TYPE* buff,
TYPE* scratch, TYPE* putBuff, TYPE* resultBuff, int rank, int nRanksPerNode, int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
@@ -701,8 +703,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
@@ -717,8 +719,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// step 1: write to scratch buffer
if (nRanksPerNode > 1) {
smChan.putPackets(scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer,
flag);
memChan.putPackets(scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer,
flag);
}
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
@@ -737,9 +739,9 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
deviceSyncer.sync(gridDim.x);
// step 3. send local reduced data to remote node.
if (threadIdx.x == 0 && blockIdx.x == 0) {
proxyChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
if ((flag & 63) == 0) {
proxyChan.flush();
portChan.flush();
}
}
// step 4. try to read the data from scratch buffer and write to local peers
@@ -756,7 +758,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
for (int index = 0; index < nPeersInNode; index++) {
smChans[index].write(offset, packet);
memChans[index].write(offset, packet);
}
dst[idx] = res;
}
@@ -787,7 +789,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// Barrier among all devices
// Should be called by all threads on all devices
// Assumes \p num_threads_per_block >= \p num_ranks
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
__forceinline__ __device__ void barrier(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
int block_id, int num_blocks, int num_ranks) {
// wait for every device
if (block_id == 0) {
@@ -804,7 +806,7 @@ __forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceH
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
template <typename DataType = float, int kVecSize = 8>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
int num_ranks, size_t num_elements) {
DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;
@@ -863,7 +865,7 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDe
}
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
allreduce6(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
size_t vector_size) {
if (vector_size == 8) {

View File

@@ -1,7 +1,7 @@
import os
import cupy as cp
import ctypes
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
from mscclpp import Transport, ProxyService, MemoryDevice2DeviceSemaphore
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
@@ -48,8 +48,8 @@ class MscclppAllReduce1:
self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
type_str = type_to_str(memory.dtype)
# create a sm_channel for each remote neighbor
self.sm_channels = self.group.make_sm_channels(self.memory, self.connections)
# create a memory_channel for each remote neighbor
self.memory_channels = self.group.make_memory_channels(self.memory, self.connections)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu",
@@ -60,7 +60,7 @@ class MscclppAllReduce1:
self.device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
self.device_handles.append(self.memory_channels[rank].device_handle().raw)
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
@@ -116,8 +116,8 @@ class MscclppAllReduce2:
type_str = type_to_str(memory.dtype)
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
# create a sm_channel for each remote neighbor
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections)
# create a memory_channel for each remote neighbor
self.memory_channels = self.group.make_memory_channels_with_scratch(self.memory, self.scratch, self.connections)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
@@ -125,7 +125,7 @@ class MscclppAllReduce2:
self.device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
self.device_handles.append(self.memory_channels[rank].device_handle().raw)
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
@@ -181,11 +181,11 @@ class MscclppAllReduce3:
self.proxy_service = proxy_service
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
# create a sm_channel for each remote neighbor
self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch(
# create first-round and second-round port channels for each remote neighbor
self.fst_round_port_chans = self.group.make_port_channels_with_scratch(
self.proxy_service, self.memory, self.scratch, self.connections
)
self.snd_round_proxy_chans = self.group.make_proxy_channels(self.proxy_service, self.memory, self.connections)
self.snd_round_port_chans = self.group.make_port_channels(self.proxy_service, self.memory, self.connections)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
@@ -194,8 +194,8 @@ class MscclppAllReduce3:
self.snd_device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
self.fst_device_handles.append(self.fst_round_port_chans[rank].device_handle().raw)
self.snd_device_handles.append(self.snd_round_port_chans[rank].device_handle().raw)
self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)
@@ -261,31 +261,29 @@ class MscclppAllReduce4:
self.proxy_service = proxy_service
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
# create a sm_channel for each remote neighbor
self.sm_channels = self.group.make_sm_channels(self.memory, same_node_connections)
self.reduce_scatter_proxy_channels = self.group.make_proxy_channels_with_scratch(
# create a memory_channel for each remote neighbor
self.memory_channels = self.group.make_memory_channels(self.memory, same_node_connections)
self.reduce_scatter_port_channels = self.group.make_port_channels_with_scratch(
self.proxy_service, self.memory, self.scratch, self.connections
)
self.all_gather_proxy_channels = self.group.make_proxy_channels(
self.proxy_service, self.memory, self.connections
)
self.all_gather_port_channels = self.group.make_port_channels(self.proxy_service, self.memory, self.connections)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce4", file_dir=file_dir, macro_dict={"TYPE": type_str}
).get_compiled_kernel()
self.sm_device_handles = []
self.mem_device_handles = []
self.reduce_sactter_proxy_device_handles = []
self.all_gather_proxy_device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank and in_same_node(rank):
self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
self.mem_device_handles.append(self.memory_channels[rank].device_handle().raw)
if rank != self.group.my_rank:
self.reduce_sactter_proxy_device_handles.append(
self.reduce_scatter_proxy_channels[rank].device_handle().raw
self.reduce_scatter_port_channels[rank].device_handle().raw
)
self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)
self.all_gather_proxy_device_handles.append(self.all_gather_port_channels[rank].device_handle().raw)
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
self.mem_device_handles_cp = cp.asarray(memoryview(b"".join(self.mem_device_handles)), dtype=cp.uint8)
self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
)
@@ -306,7 +304,7 @@ class MscclppAllReduce4:
self.params = b""
self.params += pack(
self.sm_device_handles_cp,
self.mem_device_handles_cp,
self.reduce_sactter_proxy_device_handles_cp,
self.all_gather_proxy_device_handles_cp,
self.memory,
@@ -366,24 +364,26 @@ class MscclppAllReduce5:
self.put_buff = GpuBuffer(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)}
# create a sm_channel for each remote neighbor
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, same_node_connections)
self.proxy_channels = self.group.make_proxy_channels_with_scratch(
# create a memory_channel for each remote neighbor
self.memory_channels = self.group.make_memory_channels_with_scratch(
self.memory, self.scratch, same_node_connections
)
self.port_channels = self.group.make_port_channels_with_scratch(
self.proxy_service, self.put_buff, self.scratch, across_node_connections
)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu", kernel_name="allreduce5", file_dir=file_dir, macro_dict={"TYPE": type_str}
).get_compiled_kernel()
self.sm_device_handles = []
self.mem_device_handles = []
self.proxy_device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank and in_same_node(rank):
self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
self.mem_device_handles.append(self.memory_channels[rank].device_handle().raw)
if rank != self.group.my_rank and not in_same_node(rank):
self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)
self.proxy_device_handles.append(self.port_channels[rank].device_handle().raw)
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
self.mem_device_handles_cp = cp.asarray(memoryview(b"".join(self.mem_device_handles)), dtype=cp.uint8)
self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)
self.set_params(nblocks, block_size)
@@ -398,7 +398,7 @@ class MscclppAllReduce5:
self.params = b""
self.params += pack(
self.sm_device_handles_cp,
self.mem_device_handles_cp,
self.proxy_device_handles_cp,
self.memory,
self.scratch,
@@ -446,8 +446,8 @@ class MscclppAllReduce6:
self.memory.data.ptr, self.memory.data.mem.size
)
# create a sm_channel for each remote neighbor
self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
# create a device-to-device semaphore for each NVLink connection
self.semaphores = group.make_semaphore(self.nvlink_connections, MemoryDevice2DeviceSemaphore)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu",

View File

@@ -6,7 +6,7 @@
// be careful about using semaphore[my_rank] as it is an invalid semaphore and it is there just for simplicity of
// indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
d2d_semaphore(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
d2d_semaphore(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
int tid = threadIdx.x;
if (tid < nranks && tid != my_rank) {
semaphores[tid].signal();

View File

@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/memory_channel_device.hpp>
// be careful about using channels[my_rank] as it is invalid and it is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
sm_channel(mscclpp::SmChannelDeviceHandle* channels, int my_rank, int nranks, int num_elements, int use_packet) {
memory_channel(mscclpp::MemoryChannelDeviceHandle* channels, int my_rank, int nranks, int num_elements,
int use_packet) {
int tid = threadIdx.x;
int bid = blockIdx.x;
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;

View File

@@ -10,7 +10,7 @@ __device__ mscclpp::DeviceSyncer deviceSyncer;
extern "C" __global__ void __launch_bounds__(1024, 1)
nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
int nelem = nbytes / sizeof(float);
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
float* mc_ptr = (float*)nvlsPtrs.mcPtr;

View File

@@ -2,12 +2,12 @@
// Licensed under the MIT license.
#include <mscclpp/packet_device.hpp>
#include <mscclpp/proxy_channel_device.hpp>
#include <mscclpp/port_channel_device.hpp>
// be careful about using channels[my_rank] as it is invalid and it is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
int num_elements, int use_packet) {
port_channel(mscclpp::PortChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
int num_elements, int use_packet) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;

View File

@@ -22,7 +22,7 @@ from mscclpp import (
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
SmDevice2DeviceSemaphore,
MemoryDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
is_nvls_supported,
@@ -363,9 +363,9 @@ class MscclppKernel:
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = nranks
elif test_name == "sm_channel":
elif test_name == "memory_channel":
self._kernel = KernelBuilder(
file="sm_channel_test.cu", kernel_name="sm_channel", file_dir=file_dir
file="memory_channel_test.cu", kernel_name="memory_channel", file_dir=file_dir
).get_compiled_kernel()
self.nblocks = nranks
self.nthreads = 1024
@@ -381,9 +381,9 @@ class MscclppKernel:
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = nranks
elif test_name == "proxy_channel":
elif test_name == "port_channel":
self._kernel = KernelBuilder(
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
file="port_channel_test.cu", kernel_name="port_channel", file_dir=file_dir
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = 1024
@@ -411,11 +411,11 @@ class MscclppKernel:
# keep a reference to the device handles so that they don't get garbage collected
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
if test_name in ["h2d_semaphore", "d2d_semaphore", "memory_channel", "port_channel"]:
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
if test_name == "sm_channel":
if test_name == "memory_channel":
self.params += pack(tensor.size, use_packet)
if test_name == "proxy_channel":
if test_name == "port_channel":
self.params += pack(tensor, scratch, tensor.size, use_packet)
elif test_name == "fifo":
self.params = fifo.device_handle().raw
@@ -457,7 +457,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
def test_d2d_semaphores(mpi_group: MpiGroup):
group, connections = create_group_and_connection(mpi_group, "NVLink")
semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
semaphores = group.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
group.barrier()
kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
kernel()
@@ -468,7 +468,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("use_packet", [False, True])
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
def test_memory_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, "NVLink")
memory = GpuBuffer(nelem, dtype=cp.int32)
@@ -483,10 +483,10 @@ def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
if use_packet:
channels = group.make_sm_channels_with_scratch(memory, scratch, connections)
channels = group.make_memory_channels_with_scratch(memory, scratch, connections)
else:
channels = group.make_sm_channels(memory, connections)
kernel = MscclppKernel("sm_channel", group.my_rank, group.nranks, channels, memory, use_packet, scratch)
channels = group.make_memory_channels(memory, connections)
kernel = MscclppKernel("memory_channel", group.my_rank, group.nranks, channels, memory, use_packet, scratch)
group.barrier()
kernel()
@@ -565,7 +565,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@pytest.mark.parametrize("use_packet", [False, True])
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, transport)
memory = GpuBuffer(nelem, dtype=cp.int32)
@@ -586,10 +586,10 @@ def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_pack
memory_to_register = scratch
else:
memory_to_register = memory
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
channels = group.make_port_channels(proxy_service, memory_to_register, connections)
kernel = MscclppKernel(
"proxy_channel",
"port_channel",
my_rank=group.my_rank,
nranks=group.nranks,
semaphore_or_channels=channels,
@@ -614,7 +614,7 @@ def test_nvls(mpi_group: MpiGroup):
mem_handle = nvls_connection.allocate_bind_memory(nbytes)
nvlinks_connections = create_connection(group, "NVLink")
semaphores = group.make_semaphore(nvlinks_connections, SmDevice2DeviceSemaphore)
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
kernel = MscclppKernel(
"nvls",