diff --git a/python/mscclpp/ext/alltoallv_single.py b/python/mscclpp/ext/alltoallv_single.py
index 6f583ae7..3923b91b 100644
--- a/python/mscclpp/ext/alltoallv_single.py
+++ b/python/mscclpp/ext/alltoallv_single.py
@@ -17,11 +17,11 @@
 import torch
 import torch.distributed as dist
 from typing import Optional, List, Tuple
 from mscclpp._mscclpp import (
-    Communicator,
-    TcpBootstrap,
-    DataType,
-    ReduceOp,
-    CommResult,
+    CppCommunicator as Communicator,
+    CppTcpBootstrap as TcpBootstrap,
+    CppDataType as DataType,
+    CppReduceOp as ReduceOp,
+    CppCommResult as CommResult,
 )
 from mscclpp.ext.algorithm_collection_builder import AlgorithmCollectionBuilder
@@ -323,6 +323,7 @@ class MscclppAlltoAllV:
             None,  # executor (not needed for native algos)
             0,  # nblocks (auto)
             0,  # nthreads_per_block (auto)
+            False,  # symmetric_memory
             self._extras,
         )
 
diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index e8797e43..d45fb6f4 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -130,11 +130,11 @@ def main():
     print("=" * 60)
 
     # Import after torch.distributed init
-    from mscclpp._mscclpp import (
+    from mscclpp import (
         Communicator,
         TcpBootstrap,
-        UniqueId,
     )
+    from mscclpp._mscclpp import CppUniqueId as UniqueId
     from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
 
     # Create mscclpp communicator with TcpBootstrap
diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
index 4a57d30d..8d1fae83 100644
--- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
+++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
@@ -67,7 +67,8 @@ std::shared_ptr AlltoallvFullmesh::build() {
       [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize,
              size_t outputSize, DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream,
              int nBlocks, int nThreadsPerBlock,
-             const std::unordered_map& extras) {
+             const std::unordered_map& extras,
+             [[maybe_unused]] DataType accumDtype) -> CommResult {
         return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream,
                                          nBlocks, nThreadsPerBlock, extras);
       },
@@ -77,7 +78,8 @@ std::shared_ptr AlltoallvFullmesh::build() {
         return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
       },
       // Context key generation function
-      [self](const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype) {
+      [self](const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype,
+             [[maybe_unused]] bool symmetricMemory) {
         return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
       });
 
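
Taken together, the Python-side hunks establish a new import convention: the raw C++ binding classes are exposed from mscclpp._mscclpp under Cpp*-prefixed names, while the top-level mscclpp package re-exports the user-facing names. Below is a minimal sketch of the two resulting import patterns; it uses only names confirmed by the diff above, and any other attribute of the mscclpp namespace would be an assumption.

    # Pattern 1 (library-internal, as in alltoallv_single.py above):
    # import the raw bindings under their new Cpp* names and alias them
    # locally so the rest of the module can keep using the old names.
    from mscclpp._mscclpp import (
        CppCommunicator as Communicator,
        CppTcpBootstrap as TcpBootstrap,
        CppDataType as DataType,
        CppReduceOp as ReduceOp,
        CppCommResult as CommResult,
    )

    # Pattern 2 (test/user code, as in test_alltoallv_mscclpp.py above):
    # prefer the top-level re-exports where they exist, and fall back to an
    # aliased raw binding for names without a top-level export. (Shown as an
    # alternative to Pattern 1; running both together simply shadows the
    # earlier aliases.)
    from mscclpp import Communicator, TcpBootstrap
    from mscclpp._mscclpp import CppUniqueId as UniqueId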