mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Torch integration (#692)
Reorganize current native algorithm implementation and DSL algorithm implementation. Provide unified API for DSL algo and native algo and provide interface to tune the algo Provide interface for pytorch integration with native API and DSL --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from mscclpp import (
|
||||
npkit,
|
||||
env,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp import CommGroup, GpuBuffer
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
import os
|
||||
import struct
|
||||
@@ -180,7 +180,7 @@ def main(
|
||||
n_iters: int = 10,
|
||||
n_graph_iters: int = 10,
|
||||
):
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
mscclpp_group = CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
npkit_dump_dir = env().npkit_dump_dir
|
||||
|
||||
@@ -13,7 +13,6 @@ import pytest
|
||||
|
||||
from mscclpp import (
|
||||
ErrorCode,
|
||||
Error,
|
||||
DataType,
|
||||
EndpointConfig,
|
||||
ExecutionPlan,
|
||||
@@ -31,8 +30,8 @@ from mscclpp import (
|
||||
Device,
|
||||
DeviceType,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
from mscclpp import CommGroup, GpuBuffer
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
|
||||
@@ -75,7 +74,7 @@ def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
|
||||
# ranks are on different nodes
|
||||
pytest.skip("this case is not supported as localhost will be different for different nodes")
|
||||
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm, ifIpPortTrio)
|
||||
group = CommGroup(mpi_group.comm, ifIpPortTrio)
|
||||
|
||||
nelem = 1024
|
||||
memory = np.zeros(nelem, dtype=np.int32)
|
||||
@@ -141,7 +140,7 @@ def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
|
||||
mpi_group.comm.barrier()
|
||||
|
||||
|
||||
def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
|
||||
def create_connection(group: CommGroup, connection_type: str):
|
||||
if connection_type == "NVLS":
|
||||
all_ranks = list(range(group.nranks))
|
||||
tran = Transport.CudaIpc
|
||||
@@ -163,7 +162,7 @@ def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
|
||||
def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
|
||||
if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink/nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
group = CommGroup(mpi_group.comm)
|
||||
try:
|
||||
connection = create_connection(group, connection_type)
|
||||
except Error as e:
|
||||
@@ -282,7 +281,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str,
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
group = CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
@@ -302,7 +301,7 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
group = CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
@@ -339,7 +338,7 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
def test_nvls_connection(mpi_group: MpiGroup):
|
||||
if all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
group = CommGroup(mpi_group.comm)
|
||||
all_ranks = list(range(group.nranks))
|
||||
nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
|
||||
memory1 = GpuBuffer(2**29, cp.int8)
|
||||
@@ -659,7 +658,7 @@ def test_executor(mpi_group: MpiGroup, filename: str):
|
||||
if all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("algo not support cross node")
|
||||
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
mscclpp_group = CommGroup(mpi_group.comm)
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
npkit_dump_dir = env().npkit_dump_dir
|
||||
if npkit_dump_dir != "":
|
||||
|
||||
Reference in New Issue
Block a user