mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-08 15:30:41 +00:00
Renaming channels (#436)
Renamed `ProxyChannel` to `PortChannel` and `SmChannel` to `MemoryChannel`
This commit is contained in:
@@ -28,7 +28,7 @@ def allgather_test(gpus, instances):
|
||||
c = chunk(n, Buffer.input, 0, 1)
|
||||
for peer in range(gpus):
|
||||
if n != peer:
|
||||
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
|
||||
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.memory)
|
||||
else:
|
||||
c.copy(n, Buffer.output, n, sendtb=peer)
|
||||
# explicit barrier
|
||||
@@ -36,13 +36,13 @@ def allgather_test(gpus, instances):
|
||||
r.barrier(tb_list=list(range(gpus)))
|
||||
for peer in range(gpus):
|
||||
if n != peer:
|
||||
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
|
||||
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.memory)
|
||||
|
||||
for n in range(gpus):
|
||||
for peer in range(gpus):
|
||||
c = chunk(n, Buffer.output, peer, 1)
|
||||
if n != peer:
|
||||
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm)
|
||||
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.memory)
|
||||
|
||||
Json()
|
||||
Check()
|
||||
|
||||
@@ -10,9 +10,9 @@ from mscclpp.language.types import ChannelType
|
||||
|
||||
def send_recv(instances):
|
||||
"""
|
||||
Send and receive data between two ranks using proxy channels, with LL protocol and double scratch buffer.
|
||||
Send and receive data between two ranks using port channels, with LL protocol and double scratch buffer.
|
||||
Steps:
|
||||
1. Each rank sends a chunk to every other rank's scratch buffer with packet format via proxy channel.
|
||||
1. Each rank sends a chunk to every other rank's scratch buffer with packet format via port channel.
|
||||
2. Wait for the data to be received, then copy it to the output buffer.
|
||||
"""
|
||||
size = 2
|
||||
@@ -36,7 +36,7 @@ def send_recv(instances):
|
||||
"scratch",
|
||||
1,
|
||||
sendtb=0,
|
||||
chan_type=ChannelType.proxy,
|
||||
chan_type=ChannelType.port,
|
||||
temp_buffer="scratch",
|
||||
temp_buffer_index=0,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from mscclpp.language.types import ChannelType
|
||||
|
||||
def send_recv(instances):
|
||||
"""
|
||||
Send and receive data between two ranks using proxy channels.
|
||||
Send and receive data between two ranks using port channels.
|
||||
steps:
|
||||
1. Each rank sends a chunk to the other rank's scratch buffer and signals the other rank that the data has been sent.
|
||||
2. Wait for the data to be received then copy it to the output buffer.
|
||||
@@ -34,14 +34,14 @@ def send_recv(instances):
|
||||
"scratch",
|
||||
1,
|
||||
sendtb=0,
|
||||
chan_type=ChannelType.proxy,
|
||||
chan_type=ChannelType.port,
|
||||
)
|
||||
c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
|
||||
c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
|
||||
c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
|
||||
c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
|
||||
|
||||
for r in range(size):
|
||||
c = chunk(r, "scratch", 1)
|
||||
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy)
|
||||
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.port)
|
||||
c.copy(r, Buffer.output, 0, sendtb=0)
|
||||
|
||||
Json()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os as _os
|
||||
import os
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
from ._mscclpp import (
|
||||
Env,
|
||||
@@ -22,9 +24,9 @@ from ._mscclpp import (
|
||||
numa,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
ProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
PortChannel,
|
||||
MemoryChannel,
|
||||
MemoryDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
@@ -39,17 +41,82 @@ from ._mscclpp import (
|
||||
npkit,
|
||||
)
|
||||
|
||||
__version__ = version()
|
||||
|
||||
if _os.environ.get("MSCCLPP_HOME", None) is None:
|
||||
_os.environ["MSCCLPP_HOME"] = _os.path.abspath(_os.path.dirname(__file__))
|
||||
__all__ = [
|
||||
"Communicator",
|
||||
"Connection",
|
||||
"connect_nvls_collective",
|
||||
"EndpointConfig",
|
||||
"Fifo",
|
||||
"Host2DeviceSemaphore",
|
||||
"Host2HostSemaphore",
|
||||
"numa",
|
||||
"ProxyService",
|
||||
"RegisteredMemory",
|
||||
"PortChannel",
|
||||
"MemoryChannel",
|
||||
"MemoryDevice2DeviceSemaphore",
|
||||
"TcpBootstrap",
|
||||
"Transport",
|
||||
"TransportFlags",
|
||||
"DataType",
|
||||
"Executor",
|
||||
"ExecutionPlan",
|
||||
"PacketType",
|
||||
"version",
|
||||
"is_nvls_supported",
|
||||
"alloc_shared_physical_cuda",
|
||||
"npkit",
|
||||
"__version__",
|
||||
"get_include",
|
||||
"get_lib",
|
||||
### Deprecated ###
|
||||
"ProxyChannel",
|
||||
"SmChannel",
|
||||
"SmDevice2DeviceSemaphore",
|
||||
]
|
||||
|
||||
__version__: str = str(version())
|
||||
|
||||
if os.environ.get("MSCCLPP_HOME", None) is None:
|
||||
os.environ["MSCCLPP_HOME"] = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
def get_include():
|
||||
def get_include() -> str:
|
||||
"""Return the directory that contains the MSCCL++ headers."""
|
||||
return _os.path.join(_os.path.dirname(__file__), "include")
|
||||
return os.path.join(os.path.dirname(__file__), "include")
|
||||
|
||||
|
||||
def get_lib():
|
||||
def get_lib() -> str:
|
||||
"""Return the directory that contains the MSCCL++ libraries."""
|
||||
return _os.path.join(_os.path.dirname(__file__), "lib")
|
||||
return os.path.join(os.path.dirname(__file__), "lib")
|
||||
|
||||
|
||||
def deprecated(new_cls):
|
||||
def decorator(old_cls):
|
||||
@wraps(old_cls)
|
||||
def wrapper(*args, **kwargs):
|
||||
warnings.warn(
|
||||
f"{old_cls.__name__} is deprecated, use {new_cls.__name__} instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return new_cls(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@deprecated(PortChannel)
|
||||
class ProxyChannel(PortChannel):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated(MemoryChannel)
|
||||
class SmChannel(MemoryChannel):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated(MemoryDevice2DeviceSemaphore)
|
||||
class SmDevice2DeviceSemaphore(MemoryDevice2DeviceSemaphore):
|
||||
pass
|
||||
|
||||
@@ -14,9 +14,9 @@ from ._mscclpp import (
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
ProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
PortChannel,
|
||||
MemoryChannel,
|
||||
MemoryDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
@@ -135,7 +135,7 @@ class CommGroup:
|
||||
def make_semaphore(
|
||||
self,
|
||||
connections: dict[int, Connection],
|
||||
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[SmDevice2DeviceSemaphore],
|
||||
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[MemoryDevice2DeviceSemaphore],
|
||||
) -> dict[int, Host2HostSemaphore]:
|
||||
semaphores = {}
|
||||
for rank in connections:
|
||||
@@ -143,33 +143,35 @@ class CommGroup:
|
||||
self.communicator.setup()
|
||||
return semaphores
|
||||
|
||||
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:
|
||||
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
|
||||
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
for rank in connections:
|
||||
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
|
||||
channels[rank] = MemoryChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
|
||||
return channels
|
||||
|
||||
def make_sm_channels_with_scratch(
|
||||
def make_memory_channels_with_scratch(
|
||||
self,
|
||||
tensor: cp.ndarray,
|
||||
scratchTensor: cp.ndarray,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, SmChannel]:
|
||||
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
|
||||
for rank in connections:
|
||||
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
|
||||
channels[rank] = MemoryChannel(
|
||||
semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr
|
||||
)
|
||||
return channels
|
||||
|
||||
def make_proxy_channels(
|
||||
def make_port_channels(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
) -> dict[int, SmChannel]:
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
@@ -180,18 +182,16 @@ class CommGroup:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = proxy_service.proxy_channel(
|
||||
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
|
||||
)
|
||||
channels[rank] = proxy_service.port_channel(semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank])
|
||||
return channels
|
||||
|
||||
def make_proxy_channels_with_scratch(
|
||||
def make_port_channels_with_scratch(
|
||||
self,
|
||||
proxy_service: ProxyService,
|
||||
tensor: cp.ndarray,
|
||||
scratchTensor: cp.ndarray,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, SmChannel]:
|
||||
) -> dict[int, MemoryChannel]:
|
||||
transport_flags = TransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
@@ -218,21 +218,19 @@ class CommGroup:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = proxy_service.proxy_channel(
|
||||
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
|
||||
)
|
||||
channels[rank] = proxy_service.port_channel(semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank])
|
||||
return channels
|
||||
|
||||
def register_semaphore_with_proxy(
|
||||
self, proxy_service: ProxyService, connections: dict[int, Connection]
|
||||
) -> dict[int, SmChannel]:
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphore_ids = {}
|
||||
for rank in semaphores:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = proxy_service.base_proxy_channel(semaphore_ids[rank])
|
||||
channels[rank] = proxy_service.base_port_channel(semaphore_ids[rank])
|
||||
return channels
|
||||
|
||||
def register_memory_with_proxy(
|
||||
|
||||
@@ -15,8 +15,8 @@ using namespace mscclpp;
|
||||
|
||||
extern void register_env(nb::module_& m);
|
||||
extern void register_error(nb::module_& m);
|
||||
extern void register_proxy_channel(nb::module_& m);
|
||||
extern void register_sm_channel(nb::module_& m);
|
||||
extern void register_port_channel(nb::module_& m);
|
||||
extern void register_memory_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
@@ -187,8 +187,8 @@ void register_core(nb::module_& m) {
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
register_env(m);
|
||||
register_error(m);
|
||||
register_proxy_channel(m);
|
||||
register_sm_channel(m);
|
||||
register_port_channel(m);
|
||||
register_memory_channel(m);
|
||||
register_fifo(m);
|
||||
register_semaphore(m);
|
||||
register_utils(m);
|
||||
|
||||
@@ -6,7 +6,6 @@ from mscclpp.language.chunk import Chunk, ReduceChunk
|
||||
|
||||
|
||||
class Collective:
|
||||
|
||||
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
|
||||
self.num_ranks = num_ranks
|
||||
self.chunk_factor = chunk_factor
|
||||
@@ -36,7 +35,6 @@ class Collective:
|
||||
|
||||
|
||||
class AllToAll(Collective):
|
||||
|
||||
def __init__(self, num_ranks, chunk_factor, inplace):
|
||||
Collective.__init__(self, num_ranks, chunk_factor, inplace)
|
||||
self.name = "alltoall"
|
||||
@@ -137,7 +135,6 @@ class AllGather(Collective):
|
||||
|
||||
|
||||
class AllReduce(Collective):
|
||||
|
||||
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
|
||||
num_chunk_groups = kwargs.get("num_chunk_groups", num_ranks)
|
||||
Collective.__init__(
|
||||
|
||||
@@ -221,7 +221,7 @@ class InstructionDAG:
|
||||
next=set(),
|
||||
prev=set(),
|
||||
tb=tb,
|
||||
channel_type=ChannelType.proxy,
|
||||
channel_type=ChannelType.port,
|
||||
step=tb_step,
|
||||
)
|
||||
buffer = send_ref.buffer
|
||||
|
||||
@@ -19,7 +19,6 @@ from mscclpp.language.types import ChunkRef, ChannelType, Instruction, Op, Threa
|
||||
|
||||
|
||||
class _InstructionOptimizer:
|
||||
|
||||
def try_merge_same_instructions(
|
||||
self,
|
||||
op: Op,
|
||||
@@ -128,8 +127,8 @@ class _InstructionOptimizer:
|
||||
and same_tb(op, next_op)
|
||||
and same_count(op, next_op)
|
||||
and buf_dst_src_match(op, next_op)
|
||||
and next_op.channel_type == ChannelType.sm
|
||||
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
|
||||
and next_op.channel_type == ChannelType.memory
|
||||
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.memory)
|
||||
and not circular_dep_after_merge(op, next_op)
|
||||
and all_prevs_visited_after_merge(op, next_op)
|
||||
):
|
||||
@@ -140,10 +139,10 @@ class _InstructionOptimizer:
|
||||
op.inst = Instruction.read_reduce_copy_send
|
||||
elif op.inst == Instruction.reduce:
|
||||
op.inst = Instruction.reduce_send
|
||||
op.channel_type = ChannelType.sm
|
||||
op.channel_type = ChannelType.memory
|
||||
elif op.inst == Instruction.reduce_packet:
|
||||
op.inst = Instruction.reduce_send_packet
|
||||
op.channel_type = ChannelType.sm
|
||||
op.channel_type = ChannelType.memory
|
||||
# Append the destination chunk from next_op
|
||||
op.dsts.append(
|
||||
(
|
||||
@@ -158,11 +157,11 @@ class _InstructionOptimizer:
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_fuse_instructions_using_proxy_channel(
|
||||
def try_fuse_instructions_using_port_channel(
|
||||
self, op: Op, next_op: Op, tb: Threadblock, queue: list, expected_next_inst: Instruction
|
||||
) -> bool:
|
||||
"""
|
||||
Attempts to fuse operations which using proxy channel.
|
||||
Attempts to fuse operations that use a port channel.
|
||||
:param op: The current operation.
|
||||
:param next_op: The next operation to potentially merge with.
|
||||
:param tb: The thread block containing the operations.
|
||||
@@ -177,7 +176,7 @@ class _InstructionOptimizer:
|
||||
and same_buf_dst(op, next_op)
|
||||
and same_buf_src(op, next_op)
|
||||
and same_chan_type(op, next_op)
|
||||
and op.channel_type == ChannelType.proxy
|
||||
and op.channel_type == ChannelType.port
|
||||
and not circular_dep_after_merge(op, next_op)
|
||||
and all_prevs_visited_after_merge(op, next_op)
|
||||
):
|
||||
@@ -229,7 +228,6 @@ class _InstructionOptimizer:
|
||||
|
||||
|
||||
class DagOptimizer:
|
||||
|
||||
def __init__(self, instruction_dag: InstructionDAG):
|
||||
self.optimizer = _InstructionOptimizer()
|
||||
self.dag = instruction_dag
|
||||
@@ -257,7 +255,7 @@ class DagOptimizer:
|
||||
queue = queue[1:]
|
||||
|
||||
def fuse_instructions(self):
|
||||
self._fuse_instructions_using_proxy_channel()
|
||||
self._fuse_instructions_using_port_channel()
|
||||
self._fuse_same_instructions()
|
||||
self._optimize_rrcs_rs()
|
||||
self._optimize_group_ops()
|
||||
@@ -267,7 +265,7 @@ class DagOptimizer:
|
||||
# -> putWithSignal(src, sbuf, si, dst, dbuf, di)
|
||||
# put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) flush(src, sbuf, si, dst, dbuf, di)
|
||||
# -> putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
|
||||
def _fuse_instructions_using_proxy_channel(self):
|
||||
def _fuse_instructions_using_port_channel(self):
|
||||
inst_followup_map = {
|
||||
Instruction.put: Instruction.signal,
|
||||
Instruction.put_with_signal: Instruction.flush,
|
||||
@@ -280,7 +278,7 @@ class DagOptimizer:
|
||||
fused = False
|
||||
if op.inst in inst_followup_map:
|
||||
for next_op in op.next:
|
||||
fused = self.optimizer.try_fuse_instructions_using_proxy_channel(
|
||||
fused = self.optimizer.try_fuse_instructions_using_port_channel(
|
||||
op, next_op, tb, queue, inst_followup_map[op.inst]
|
||||
)
|
||||
if fused:
|
||||
|
||||
@@ -286,7 +286,7 @@ class _ReadReduceCopySendConverter(_OpConverter):
|
||||
class _ReduceSendConverter(_OpConverter):
|
||||
def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
|
||||
dst_channel_ids = self.get_channel_ids(
|
||||
op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm
|
||||
op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.memory
|
||||
)
|
||||
o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
|
||||
srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
|
||||
|
||||
@@ -222,7 +222,7 @@ class Ref(ChunkRef):
|
||||
return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
|
||||
return buffer, index
|
||||
|
||||
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
|
||||
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory, use_packet=False):
|
||||
self.prog.check_buffer_exists(dst, buffer)
|
||||
assert self.rank != dst, "Cannot put to the same rank"
|
||||
buffer, index = self._get_buffer_index(dst, buffer, index)
|
||||
@@ -237,7 +237,7 @@ class Ref(ChunkRef):
|
||||
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
|
||||
return dst_chunkref
|
||||
|
||||
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
|
||||
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
|
||||
return self._put(dst, buffer, index, sendtb, chan_type)
|
||||
|
||||
def put_packet(
|
||||
@@ -246,19 +246,19 @@ class Ref(ChunkRef):
|
||||
buffer=None,
|
||||
index=-1,
|
||||
sendtb=-1,
|
||||
chan_type=ChannelType.sm,
|
||||
chan_type=ChannelType.memory,
|
||||
temp_buffer=None,
|
||||
temp_buffer_index=-1,
|
||||
):
|
||||
chunk_ref = self
|
||||
if chan_type == ChannelType.proxy:
|
||||
assert temp_buffer is not None, "Need to specify a temporary buffer for proxy channels"
|
||||
if chan_type == ChannelType.port:
|
||||
assert temp_buffer is not None, "Need to specify a temporary buffer for port channels"
|
||||
chunk_ref = self._copy(
|
||||
self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
|
||||
)
|
||||
return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)
|
||||
|
||||
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
|
||||
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
|
||||
self.prog.check_buffer_exists(src, buffer)
|
||||
sender = src
|
||||
receiver = self.rank
|
||||
@@ -273,7 +273,7 @@ class Ref(ChunkRef):
|
||||
# for signal and wait, currently we assume the pair will use the same tb index. In future we need
|
||||
# to infer the tb index from the instruction DAG. A channel is defined as (send_tb, src_buffer, recv_tb, dst_buffer, type).
|
||||
# Then we can use DAG info to reduce the number of channels.
|
||||
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
|
||||
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
|
||||
sender = self.rank
|
||||
receiver = dst
|
||||
assert sender != receiver, "Cannot signal to the same rank"
|
||||
@@ -282,9 +282,9 @@ class Ref(ChunkRef):
|
||||
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
|
||||
self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)
|
||||
|
||||
# only proxy channel need to use this function
|
||||
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy):
|
||||
assert chan_type == ChannelType.proxy, "Only proxy channel can use flush"
|
||||
# Only port channels need to use this function.
|
||||
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port):
|
||||
assert chan_type == ChannelType.port, "Only port channel can use flush"
|
||||
sender = self.rank
|
||||
receiver = dst
|
||||
assert sender != receiver, "Cannot flush to the same rank"
|
||||
@@ -293,7 +293,7 @@ class Ref(ChunkRef):
|
||||
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
|
||||
self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)
|
||||
|
||||
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
|
||||
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
|
||||
sender = src
|
||||
receiver = self.rank
|
||||
assert sender != receiver, "Cannot wait on the same rank"
|
||||
@@ -324,7 +324,7 @@ class Ref(ChunkRef):
|
||||
def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
|
||||
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)
|
||||
|
||||
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
|
||||
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False):
|
||||
dst = self.rank
|
||||
src = other_chunkref.rank
|
||||
|
||||
@@ -342,7 +342,7 @@ class Ref(ChunkRef):
|
||||
return self
|
||||
|
||||
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
|
||||
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm):
|
||||
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory):
|
||||
return self._reduce(other_chunkref, recvtb, channel_type)
|
||||
|
||||
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
|
||||
|
||||
@@ -114,11 +114,15 @@ class ChunkRef:
|
||||
|
||||
|
||||
class ChannelType(Enum):
|
||||
proxy = "proxy"
|
||||
sm = "sm"
|
||||
port = "port"
|
||||
memory = "memory"
|
||||
none = "none"
|
||||
nvls = "nvls"
|
||||
|
||||
# Deprecated
|
||||
proxy = "port"
|
||||
sm = "memory"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
35
python/mscclpp/memory_channel_py.cpp
Normal file
35
python/mscclpp/memory_channel_py.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<MemoryChannel> memoryChannel(m, "MemoryChannel");
|
||||
memoryChannel
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, uintptr_t src) { new (memoryChannel) MemoryChannel(semaphore, dst, (void*)src); })
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, uintptr_t src, uintptr_t get_packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
|
||||
})
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &MemoryChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
@@ -5,12 +5,12 @@
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/proxy_channel.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_proxy_channel(nb::module_& m) {
|
||||
void register_port_channel(nb::module_& m) {
|
||||
nb::class_<BaseProxyService>(m, "BaseProxyService")
|
||||
.def("start_proxy", &BaseProxyService::startProxy)
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
@@ -23,36 +23,36 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("base_proxy_channel", &ProxyService::baseProxyChannel, nb::arg("id"))
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BaseProxyChannel>(m, "BaseProxyChannel")
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BaseProxyChannel::deviceHandle);
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseProxyChannel::DeviceHandle>(m, "BaseProxyChannelDeviceHandle")
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &BaseProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BaseProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &BaseProxyChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const BaseProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
.def_rw("semaphoreId_", &BasePortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &BasePortChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const BasePortChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
|
||||
.def_rw("src_", &ProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &ProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
.def_rw("semaphoreId_", &PortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &PortChannel::DeviceHandle::fifo_)
|
||||
.def_rw("src_", &PortChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &PortChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const PortChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
@@ -33,18 +33,18 @@ void register_semaphore(nb::module_& m) {
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
|
||||
smDevice2DeviceSemaphore
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
.def_rw("inboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/sm_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_sm_channel(nb::module_& m) {
|
||||
nb::class_<SmChannel> smChannel(m, "SmChannel");
|
||||
smChannel
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src, uintptr_t get_packet_buffer) {
|
||||
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
|
||||
})
|
||||
.def("device_handle", &SmChannel::deviceHandle);
|
||||
|
||||
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &SmChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
@@ -8,9 +8,9 @@
|
||||
#endif
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/proxy_channel_device.hpp>
|
||||
#include <mscclpp/sm_channel_device.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
|
||||
@@ -124,7 +124,7 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
|
||||
// -------------------------------------------
|
||||
|
||||
template <int READ_ONLY>
|
||||
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
|
||||
__device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank, int nranks,
|
||||
size_t nelems) {
|
||||
const size_t chunkSize = nelems / nranks;
|
||||
if (nranks == 1) return;
|
||||
@@ -140,10 +140,10 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < nPeer) {
|
||||
smChans[tid].relaxedSignal();
|
||||
memChans[tid].relaxedSignal();
|
||||
}
|
||||
if (tid >= nPeer && tid < nPeer * 2) {
|
||||
smChans[tid - nPeer].wait();
|
||||
memChans[tid - nPeer].wait();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -155,14 +155,14 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
|
||||
int4 val;
|
||||
int peerIdx = (index + rank);
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
|
||||
val = memChans[peerIdx].read<int4>(indexOffset4 + idx);
|
||||
tmp = add_vectors<TYPE>(tmp, val);
|
||||
}
|
||||
if (READ_ONLY == 0) {
|
||||
for (int index = 0; index < nPeer; ++index) {
|
||||
int peerIdx = (index + rank);
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
smChans[peerIdx].write<int4>(indexOffset4 + idx, tmp);
|
||||
memChans[peerIdx].write<int4>(indexOffset4 + idx, tmp);
|
||||
}
|
||||
}
|
||||
buff4[indexOffset4 + idx] = tmp;
|
||||
@@ -178,14 +178,14 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
|
||||
for (int index = 0; index < nPeer; ++index) {
|
||||
int peerIdx = (index + rank);
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
TYPE val = smChans[peerIdx].read<TYPE>(idx);
|
||||
TYPE val = memChans[peerIdx].read<TYPE>(idx);
|
||||
tmp += val;
|
||||
}
|
||||
if (READ_ONLY == 0) {
|
||||
for (int index = 0; index < nPeer; ++index) {
|
||||
int peerIdx = (index + rank);
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
smChans[peerIdx].write<TYPE>(idx, tmp);
|
||||
memChans[peerIdx].write<TYPE>(idx, tmp);
|
||||
}
|
||||
}
|
||||
buff[idx] = tmp;
|
||||
@@ -198,10 +198,10 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < nPeer) {
|
||||
smChans[tid].relaxedSignal();
|
||||
memChans[tid].relaxedSignal();
|
||||
}
|
||||
if (tid >= nPeer && tid < nPeer * 2) {
|
||||
smChans[tid - nPeer].wait();
|
||||
memChans[tid - nPeer].wait();
|
||||
}
|
||||
|
||||
if (READ_ONLY) {
|
||||
@@ -211,17 +211,18 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
const int remoteRank = (peerIdx < rank ? peerIdx : peerIdx + 1);
|
||||
size_t offset = chunkSize * remoteRank * sizeof(TYPE);
|
||||
smChans[peerIdx].get(offset, chunkSize * sizeof(TYPE), tid, blockDim.x * gridDim.x);
|
||||
memChans[peerIdx].get(offset, chunkSize * sizeof(TYPE), tid, blockDim.x * gridDim.x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
|
||||
int rank, int nranks, size_t nelems, int read_only) {
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce1(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank, int nranks, size_t nelems,
|
||||
int read_only) {
|
||||
if (read_only)
|
||||
allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
|
||||
allreduce1_helper<1>(memChans, buff, rank, nranks, nelems);
|
||||
else
|
||||
allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
|
||||
allreduce1_helper<0>(memChans, buff, rank, nranks, nelems);
|
||||
}
|
||||
|
||||
// -------------------------------------------
|
||||
@@ -231,7 +232,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChan
|
||||
__device__ uint64_t globalFlag = 1;
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
|
||||
allreduce2(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
|
||||
int worldSize, size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
// This version of allreduce only works for single nodes
|
||||
@@ -246,7 +247,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
@@ -259,7 +260,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
@@ -279,7 +280,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
smChans[index].write(offset, packet);
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
}
|
||||
// step 3: get data result from scratch buffer
|
||||
@@ -301,7 +302,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
// -------------------------------------------
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce3(mscclpp::ProxyChannelDeviceHandle* fstRoundChans, mscclpp::ProxyChannelDeviceHandle* sndRoundChans,
|
||||
allreduce3(mscclpp::PortChannelDeviceHandle* fstRoundChans, mscclpp::PortChannelDeviceHandle* sndRoundChans,
|
||||
TYPE* buff, TYPE* scratch, int rank, int worldSize, size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
|
||||
@@ -311,10 +312,10 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
int peerSendId = (remoteSendRank < rank) ? remoteSendRank : remoteSendRank - 1;
|
||||
int peerRecvId = (remoteRecvRank < rank) ? remoteRecvRank : remoteRecvRank - 1;
|
||||
|
||||
mscclpp::ProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
|
||||
mscclpp::PortChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
|
||||
mscclpp::PortChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
|
||||
mscclpp::PortChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
|
||||
mscclpp::PortChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
|
||||
|
||||
// Step 1
|
||||
size_t chunkIndex = (rank + worldSize - 1) % worldSize;
|
||||
@@ -419,9 +420,9 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
// AllReduce4
|
||||
// 2-node
|
||||
// -------------------------------------------
|
||||
__device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nRanksPerNode,
|
||||
int startChunkIndex, size_t offsetInChunk, size_t chunkSize, size_t nelems,
|
||||
int nBlocks) {
|
||||
__device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memChans, TYPE* buff, int rank,
|
||||
int nRanksPerNode, int startChunkIndex, size_t offsetInChunk, size_t chunkSize,
|
||||
size_t nelems, int nBlocks) {
|
||||
if (nRanksPerNode == 1) return;
|
||||
if (blockIdx.x >= nBlocks) return;
|
||||
const int nPeer = nRanksPerNode - 1;
|
||||
@@ -433,10 +434,10 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
|
||||
int4* buff4 = (int4*)buff;
|
||||
|
||||
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
|
||||
smChans[peerIdx].relaxedSignal();
|
||||
memChans[peerIdx].relaxedSignal();
|
||||
}
|
||||
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
|
||||
smChans[peerIdx].wait();
|
||||
memChans[peerIdx].wait();
|
||||
}
|
||||
reduceScatterDeviceSyncer.sync(nBlocks);
|
||||
|
||||
@@ -447,7 +448,7 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
|
||||
int4 val;
|
||||
int peerIdx = index + localRankIndexInNode;
|
||||
if (peerIdx >= nPeer) peerIdx -= nPeer;
|
||||
val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
|
||||
val = memChans[peerIdx].read<int4>(indexOffset4 + idx);
|
||||
tmp = add_vectors<TYPE>(tmp, val);
|
||||
}
|
||||
buff4[indexOffset4 + idx] = tmp;
|
||||
@@ -457,9 +458,9 @@ __device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TY
|
||||
}
|
||||
|
||||
// This kernel is the most performant when the number of blocks is a multiple of (nRanksPerNode - 1).
|
||||
__device__ void localAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
|
||||
int startRankChunkIndex, uint64_t offsetInRankChunk, uint64_t rankChunkSize,
|
||||
uint64_t size, size_t nBlocks) {
|
||||
__device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, int rank, int nRanksPerNode,
|
||||
int startRankChunkIndex, uint64_t offsetInRankChunk, uint64_t rankChunkSize,
|
||||
uint64_t size, size_t nBlocks) {
|
||||
if (nRanksPerNode == 1) return;
|
||||
if (blockIdx.x >= nBlocks) return;
|
||||
const size_t nPeer = nRanksPerNode - 1;
|
||||
@@ -495,16 +496,16 @@ __device__ void localAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int ra
|
||||
sizeForThisBlock += lastChunkSize;
|
||||
}
|
||||
if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
|
||||
smChans[peerIdx].relaxedSignal();
|
||||
smChans[peerIdx].wait();
|
||||
memChans[peerIdx].relaxedSignal();
|
||||
memChans[peerIdx].wait();
|
||||
}
|
||||
allGatherDeviceSyncer.sync(nBlocks);
|
||||
size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
|
||||
smChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
|
||||
memChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
|
||||
}
|
||||
|
||||
__device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
|
||||
uint64_t size, size_t nBlocks) {
|
||||
__device__ void localAllGatherAllPairsMem(mscclpp::MemoryChannelDeviceHandle* memChans, int rank, int nRanksPerNode,
|
||||
uint64_t size, size_t nBlocks) {
|
||||
if (nRanksPerNode == 1) return;
|
||||
if (blockIdx.x >= nBlocks) return;
|
||||
|
||||
@@ -512,24 +513,24 @@ __device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans
|
||||
const int nPeer = nRanksPerNode - 1;
|
||||
|
||||
if (tid < nPeer) {
|
||||
smChans[tid].signal();
|
||||
memChans[tid].signal();
|
||||
}
|
||||
int waitStart = nBlocks * blockDim.x - nPeer;
|
||||
if (tid >= waitStart && tid < nBlocks * blockDim.x) {
|
||||
smChans[tid - waitStart].wait();
|
||||
memChans[tid - waitStart].wait();
|
||||
}
|
||||
allGatherDeviceSyncer.sync(nBlocks);
|
||||
for (int i = 0; i < nPeer; ++i) {
|
||||
int peerIdx = (i + rank) % nPeer;
|
||||
const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
|
||||
size_t offset = size * remoteRankLocalIndex;
|
||||
smChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
|
||||
memChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
|
||||
}
|
||||
}
|
||||
|
||||
// This is an allgather4 equivalent
|
||||
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
|
||||
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
|
||||
__device__ void allGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* portChans,
|
||||
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
|
||||
// this allgather is a pipelined and hierarchical one and only works for two nodes
|
||||
// it is implemented as follows:
|
||||
// Step 1: each node does a local allgather and concurrently,
|
||||
@@ -544,14 +545,14 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::Pr
|
||||
int peerRank = (rank + nRanksPerNode) % worldSize;
|
||||
int peerNodeId = peerRank / nRanksPerNode;
|
||||
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
mscclpp::PortChannelDeviceHandle portChan = portChans[peer];
|
||||
const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
|
||||
const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
|
||||
const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
|
||||
const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;
|
||||
|
||||
if (peerNodeId == rank / nRanksPerNode) {
|
||||
localAllGatherSm(smChans, rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
|
||||
localAllGatherMem(memChans, rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -562,36 +563,37 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::Pr
|
||||
|
||||
// Step 1
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
|
||||
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), step1Bytes);
|
||||
portChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), step1Bytes);
|
||||
}
|
||||
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInLocalNode, 0, rankChunkSize, rankChunkSize,
|
||||
nBlocksForLocalAllGather);
|
||||
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInLocalNode, 0, rankChunkSize, rankChunkSize,
|
||||
nBlocksForLocalAllGather);
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
|
||||
proxyChan.wait();
|
||||
proxyChan.flush();
|
||||
portChan.wait();
|
||||
portChan.flush();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// Step 2
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
|
||||
portChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
|
||||
}
|
||||
if (step1Bytes > 0)
|
||||
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, 0, rankChunkSize, step1Bytes,
|
||||
nBlocksForLocalAllGather);
|
||||
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInPeerNode, 0, rankChunkSize, step1Bytes,
|
||||
nBlocksForLocalAllGather);
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
proxyChan.wait();
|
||||
proxyChan.flush();
|
||||
portChan.wait();
|
||||
portChan.flush();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// Step 3
|
||||
localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
|
||||
nBlocksForLocalAllGather);
|
||||
localAllGatherMem(memChans, rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
|
||||
nBlocksForLocalAllGather);
|
||||
}
|
||||
|
||||
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
|
||||
TYPE* buff, TYPE* scratch, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems, // must be divisible by 3
|
||||
int pipelineDepth) {
|
||||
__device__ void reduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
mscclpp::PortChannelDeviceHandle* portChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
int nRanksPerNode, int worldSize,
|
||||
size_t nelems, // must be divisible by 3
|
||||
int pipelineDepth) {
|
||||
// this reduce-scatter algorithm works as follows:
|
||||
// Step 1: each node does a local reduce-scatter on peer node data chunks with 1/pipeline portion of chunk data. For
|
||||
// example, 2 nodes and each node has 2 ranks. rank 0 and rank 1 perform reduce-scatter on chunk 2 and chunk 3, with
|
||||
@@ -612,29 +614,29 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x == nBlocksForReduceScatter);
|
||||
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
|
||||
int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
mscclpp::PortChannelDeviceHandle portChan = portChans[peer];
|
||||
if (peerNodeId == rank / nRanksPerNode) {
|
||||
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
|
||||
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
|
||||
return;
|
||||
}
|
||||
|
||||
// step 1: local reduce
|
||||
int startChunkIndex = peerNodeId * nRanksPerNode;
|
||||
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize / pipelineSize,
|
||||
nBlocksForReduceScatter);
|
||||
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize / pipelineSize,
|
||||
nBlocksForReduceScatter);
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
// step 2: local reduce and exchange data with neighbor
|
||||
if (isComm) {
|
||||
size_t offset = (peerRank * chunkSize) * sizeof(int);
|
||||
// opposite side
|
||||
proxyChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
|
||||
portChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
|
||||
}
|
||||
if (pipelineSize > 1)
|
||||
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
|
||||
(pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
|
||||
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
|
||||
(pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
|
||||
if (isComm) {
|
||||
proxyChan.wait();
|
||||
portChan.wait();
|
||||
}
|
||||
if (blockIdx.x >= nBlocksForReduceScatter) {
|
||||
ibDeviceSyncer.sync(nBlocksRemain);
|
||||
@@ -645,7 +647,7 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
|
||||
vectorSum((TYPE*)dst, (TYPE*)src, chunkSize / pipelineSize, blockIdx.x - nBlocksForReduceScatter, nBlocksRemain);
|
||||
}
|
||||
if (isComm) {
|
||||
proxyChan.flush();
|
||||
portChan.flush();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -653,12 +655,12 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
|
||||
startChunkIndex = (rank / nRanksPerNode) * nRanksPerNode;
|
||||
if (isComm && pipelineSize > 1) {
|
||||
size_t offset = (peerRank * chunkSize + chunkSize / pipelineSize) * sizeof(int);
|
||||
proxyChan.putWithSignal(offset, (pipelineSize - 1) * chunkSize / pipelineSize * sizeof(int));
|
||||
portChan.putWithSignal(offset, (pipelineSize - 1) * chunkSize / pipelineSize * sizeof(int));
|
||||
}
|
||||
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize,
|
||||
nBlocksForReduceScatter);
|
||||
localReduceScatterMem(memChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize,
|
||||
nBlocksForReduceScatter);
|
||||
if (isComm && pipelineSize > 1) {
|
||||
proxyChan.wait();
|
||||
portChan.wait();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// reduce to related rank; cannot overlap since localReduceScatter also calculates the sum
|
||||
@@ -667,24 +669,24 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp
|
||||
int* src = (int*)((char*)scratch + offset);
|
||||
if (pipelineSize > 1) vectorSum((TYPE*)dst, (TYPE*)src, (pipelineSize - 1) * chunkSize / pipelineSize);
|
||||
if (isComm) {
|
||||
proxyChan.flush();
|
||||
portChan.flush();
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) __global__
|
||||
allreduce4(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* reduceScatterProxyChans,
|
||||
mscclpp::ProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
allreduce4(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* reduceScatterPortChans,
|
||||
mscclpp::PortChannelDeviceHandle* allGatherPortChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
int nRanksPerNode, int worldSize, size_t nelems, int pipelineDepth) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
reduceScatterSm(smChans, reduceScatterProxyChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
|
||||
pipelineDepth);
|
||||
reduceScatterMem(memChans, reduceScatterPortChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
|
||||
pipelineDepth);
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
allGatherSm(smChans, allGatherProxyChans, rank, worldSize, nRanksPerNode, nelems / worldSize, pipelineDepth);
|
||||
allGatherMem(memChans, allGatherPortChans, rank, worldSize, nRanksPerNode, nelems / worldSize, pipelineDepth);
|
||||
}
|
||||
|
||||
// allreduce 5 for 2-nodes
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans, TYPE* buff,
|
||||
allreduce5(mscclpp::MemoryChannelDeviceHandle* memChans, mscclpp::PortChannelDeviceHandle* portChans, TYPE* buff,
|
||||
TYPE* scratch, TYPE* putBuff, TYPE* resultBuff, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
@@ -701,8 +703,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
|
||||
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
@@ -717,8 +719,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
if (nRanksPerNode > 1) {
|
||||
smChan.putPackets(scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer,
|
||||
flag);
|
||||
memChan.putPackets(scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer,
|
||||
flag);
|
||||
}
|
||||
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
|
||||
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
|
||||
@@ -737,9 +739,9 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// step 3. send local reduced data to remote node.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
proxyChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
|
||||
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
|
||||
if ((flag & 63) == 0) {
|
||||
proxyChan.flush();
|
||||
portChan.flush();
|
||||
}
|
||||
}
|
||||
// step 4. try to read the data from scratch buffer and write to local peers
|
||||
@@ -756,7 +758,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
|
||||
for (int index = 0; index < nPeersInNode; index++) {
|
||||
smChans[index].write(offset, packet);
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
dst[idx] = res;
|
||||
}
|
||||
@@ -787,7 +789,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
// Barrier among all devices
|
||||
// Should be called by all threads on all devices
|
||||
// Assumes \p num_threads_per_block >= \p num_ranks
|
||||
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
|
||||
__forceinline__ __device__ void barrier(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
|
||||
int block_id, int num_blocks, int num_ranks) {
|
||||
// wait for every device
|
||||
if (block_id == 0) {
|
||||
@@ -804,7 +806,7 @@ __forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceH
|
||||
|
||||
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
|
||||
template <typename DataType = float, int kVecSize = 8>
|
||||
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
|
||||
int num_ranks, size_t num_elements) {
|
||||
DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;
|
||||
@@ -863,7 +865,7 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDe
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
allreduce6(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
|
||||
size_t vector_size) {
|
||||
if (vector_size == 8) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import cupy as cp
|
||||
import ctypes
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
|
||||
from mscclpp import Transport, ProxyService, MemoryDevice2DeviceSemaphore
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
|
||||
@@ -48,8 +48,8 @@ class MscclppAllReduce1:
|
||||
self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels(self.memory, self.connections)
|
||||
# create a memory_channel for each remote neighbor
|
||||
self.memory_channels = self.group.make_memory_channels(self.memory, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu",
|
||||
@@ -60,7 +60,7 @@ class MscclppAllReduce1:
|
||||
self.device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
self.device_handles.append(self.memory_channels[rank].device_handle().raw)
|
||||
|
||||
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
|
||||
|
||||
@@ -116,8 +116,8 @@ class MscclppAllReduce2:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections)
|
||||
# create a memory_channel for each remote neighbor
|
||||
self.memory_channels = self.group.make_memory_channels_with_scratch(self.memory, self.scratch, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
@@ -125,7 +125,7 @@ class MscclppAllReduce2:
|
||||
self.device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
self.device_handles.append(self.memory_channels[rank].device_handle().raw)
|
||||
|
||||
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
|
||||
|
||||
@@ -181,11 +181,11 @@ class MscclppAllReduce3:
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch(
|
||||
# create port channels for each remote neighbor
|
||||
self.fst_round_port_chans = self.group.make_port_channels_with_scratch(
|
||||
self.proxy_service, self.memory, self.scratch, self.connections
|
||||
)
|
||||
self.snd_round_proxy_chans = self.group.make_proxy_channels(self.proxy_service, self.memory, self.connections)
|
||||
self.snd_round_port_chans = self.group.make_port_channels(self.proxy_service, self.memory, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
@@ -194,8 +194,8 @@ class MscclppAllReduce3:
|
||||
self.snd_device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank:
|
||||
self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
|
||||
self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
|
||||
self.fst_device_handles.append(self.fst_round_port_chans[rank].device_handle().raw)
|
||||
self.snd_device_handles.append(self.snd_round_port_chans[rank].device_handle().raw)
|
||||
self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
|
||||
self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)
|
||||
|
||||
@@ -261,31 +261,29 @@ class MscclppAllReduce4:
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels(self.memory, same_node_connections)
|
||||
self.reduce_scatter_proxy_channels = self.group.make_proxy_channels_with_scratch(
|
||||
# create a memory_channel for each remote neighbor
|
||||
self.memory_channels = self.group.make_memory_channels(self.memory, same_node_connections)
|
||||
self.reduce_scatter_port_channels = self.group.make_port_channels_with_scratch(
|
||||
self.proxy_service, self.memory, self.scratch, self.connections
|
||||
)
|
||||
self.all_gather_proxy_channels = self.group.make_proxy_channels(
|
||||
self.proxy_service, self.memory, self.connections
|
||||
)
|
||||
self.all_gather_port_channels = self.group.make_port_channels(self.proxy_service, self.memory, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce4", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
).get_compiled_kernel()
|
||||
self.sm_device_handles = []
|
||||
self.mem_device_handles = []
|
||||
self.reduce_sactter_proxy_device_handles = []
|
||||
self.all_gather_proxy_device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank and in_same_node(rank):
|
||||
self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
self.mem_device_handles.append(self.memory_channels[rank].device_handle().raw)
|
||||
if rank != self.group.my_rank:
|
||||
self.reduce_sactter_proxy_device_handles.append(
|
||||
self.reduce_scatter_proxy_channels[rank].device_handle().raw
|
||||
self.reduce_scatter_port_channels[rank].device_handle().raw
|
||||
)
|
||||
self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)
|
||||
self.all_gather_proxy_device_handles.append(self.all_gather_port_channels[rank].device_handle().raw)
|
||||
|
||||
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
|
||||
self.mem_device_handles_cp = cp.asarray(memoryview(b"".join(self.mem_device_handles)), dtype=cp.uint8)
|
||||
self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
|
||||
memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
|
||||
)
|
||||
@@ -306,7 +304,7 @@ class MscclppAllReduce4:
|
||||
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
self.sm_device_handles_cp,
|
||||
self.mem_device_handles_cp,
|
||||
self.reduce_sactter_proxy_device_handles_cp,
|
||||
self.all_gather_proxy_device_handles_cp,
|
||||
self.memory,
|
||||
@@ -366,24 +364,26 @@ class MscclppAllReduce5:
|
||||
self.put_buff = GpuBuffer(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, same_node_connections)
|
||||
self.proxy_channels = self.group.make_proxy_channels_with_scratch(
|
||||
# create a memory_channel for each remote neighbor
|
||||
self.memory_channels = self.group.make_memory_channels_with_scratch(
|
||||
self.memory, self.scratch, same_node_connections
|
||||
)
|
||||
self.port_channels = self.group.make_port_channels_with_scratch(
|
||||
self.proxy_service, self.put_buff, self.scratch, across_node_connections
|
||||
)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu", kernel_name="allreduce5", file_dir=file_dir, macro_dict={"TYPE": type_str}
|
||||
).get_compiled_kernel()
|
||||
self.sm_device_handles = []
|
||||
self.mem_device_handles = []
|
||||
self.proxy_device_handles = []
|
||||
for rank in range(self.group.nranks):
|
||||
if rank != self.group.my_rank and in_same_node(rank):
|
||||
self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
|
||||
self.mem_device_handles.append(self.memory_channels[rank].device_handle().raw)
|
||||
if rank != self.group.my_rank and not in_same_node(rank):
|
||||
self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)
|
||||
self.proxy_device_handles.append(self.port_channels[rank].device_handle().raw)
|
||||
|
||||
self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
|
||||
self.mem_device_handles_cp = cp.asarray(memoryview(b"".join(self.mem_device_handles)), dtype=cp.uint8)
|
||||
self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)
|
||||
|
||||
self.set_params(nblocks, block_size)
|
||||
@@ -398,7 +398,7 @@ class MscclppAllReduce5:
|
||||
|
||||
self.params = b""
|
||||
self.params += pack(
|
||||
self.sm_device_handles_cp,
|
||||
self.mem_device_handles_cp,
|
||||
self.proxy_device_handles_cp,
|
||||
self.memory,
|
||||
self.scratch,
|
||||
@@ -446,8 +446,8 @@ class MscclppAllReduce6:
|
||||
self.memory.data.ptr, self.memory.data.mem.size
|
||||
)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
|
||||
# create a memory device-to-device semaphore for each NVLink connection
|
||||
self.semaphores = group.make_semaphore(self.nvlink_connections, MemoryDevice2DeviceSemaphore)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu",
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
// be careful about using semaphore[my_rank] as it is an invalid semaphore and it is there just for simplicity of
|
||||
// indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
d2d_semaphore(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
|
||||
d2d_semaphore(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid < nranks && tid != my_rank) {
|
||||
semaphores[tid].signal();
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/sm_channel_device.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
|
||||
// be careful about using channels[my_rank] as it is invalid and it is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
sm_channel(mscclpp::SmChannelDeviceHandle* channels, int my_rank, int nranks, int num_elements, int use_packet) {
|
||||
memory_channel(mscclpp::MemoryChannelDeviceHandle* channels, int my_rank, int nranks, int num_elements,
|
||||
int use_packet) {
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
|
||||
@@ -10,7 +10,7 @@ __device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
|
||||
mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
|
||||
mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
|
||||
int nelem = nbytes / sizeof(float);
|
||||
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
|
||||
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/packet_device.hpp>
|
||||
#include <mscclpp/proxy_channel_device.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
// be careful about using channels[my_rank] as it is invalid and it is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
|
||||
int num_elements, int use_packet) {
|
||||
port_channel(mscclpp::PortChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
|
||||
int num_elements, int use_packet) {
|
||||
int tid = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
|
||||
@@ -22,7 +22,7 @@ from mscclpp import (
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
SmDevice2DeviceSemaphore,
|
||||
MemoryDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
is_nvls_supported,
|
||||
@@ -363,9 +363,9 @@ class MscclppKernel:
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = nranks
|
||||
elif test_name == "sm_channel":
|
||||
elif test_name == "memory_channel":
|
||||
self._kernel = KernelBuilder(
|
||||
file="sm_channel_test.cu", kernel_name="sm_channel", file_dir=file_dir
|
||||
file="memory_channel_test.cu", kernel_name="memory_channel", file_dir=file_dir
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = nranks
|
||||
self.nthreads = 1024
|
||||
@@ -381,9 +381,9 @@ class MscclppKernel:
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = nranks
|
||||
elif test_name == "proxy_channel":
|
||||
elif test_name == "port_channel":
|
||||
self._kernel = KernelBuilder(
|
||||
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
|
||||
file="port_channel_test.cu", kernel_name="port_channel", file_dir=file_dir
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = 1024
|
||||
@@ -411,11 +411,11 @@ class MscclppKernel:
|
||||
# keep a reference to the device handles so that they don't get garbage collected
|
||||
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
|
||||
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "memory_channel", "port_channel"]:
|
||||
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
|
||||
if test_name == "sm_channel":
|
||||
if test_name == "memory_channel":
|
||||
self.params += pack(tensor.size, use_packet)
|
||||
if test_name == "proxy_channel":
|
||||
if test_name == "port_channel":
|
||||
self.params += pack(tensor, scratch, tensor.size, use_packet)
|
||||
elif test_name == "fifo":
|
||||
self.params = fifo.device_handle().raw
|
||||
@@ -457,7 +457,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
|
||||
semaphores = group.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
group.barrier()
|
||||
kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
|
||||
kernel()
|
||||
@@ -468,7 +468,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
def test_memory_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
@@ -483,10 +483,10 @@ def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
|
||||
|
||||
if use_packet:
|
||||
channels = group.make_sm_channels_with_scratch(memory, scratch, connections)
|
||||
channels = group.make_memory_channels_with_scratch(memory, scratch, connections)
|
||||
else:
|
||||
channels = group.make_sm_channels(memory, connections)
|
||||
kernel = MscclppKernel("sm_channel", group.my_rank, group.nranks, channels, memory, use_packet, scratch)
|
||||
channels = group.make_memory_channels(memory, connections)
|
||||
kernel = MscclppKernel("memory_channel", group.my_rank, group.nranks, channels, memory, use_packet, scratch)
|
||||
|
||||
group.barrier()
|
||||
kernel()
|
||||
@@ -565,7 +565,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
@@ -586,10 +586,10 @@ def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_pack
|
||||
memory_to_register = scratch
|
||||
else:
|
||||
memory_to_register = memory
|
||||
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
|
||||
channels = group.make_port_channels(proxy_service, memory_to_register, connections)
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"proxy_channel",
|
||||
"port_channel",
|
||||
my_rank=group.my_rank,
|
||||
nranks=group.nranks,
|
||||
semaphore_or_channels=channels,
|
||||
@@ -614,7 +614,7 @@ def test_nvls(mpi_group: MpiGroup):
|
||||
mem_handle = nvls_connection.allocate_bind_memory(nbytes)
|
||||
|
||||
nvlinks_connections = create_connection(group, "NVLink")
|
||||
semaphores = group.make_semaphore(nvlinks_connections, SmDevice2DeviceSemaphore)
|
||||
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"nvls",
|
||||
|
||||
Reference in New Issue
Block a user