From 02fdcffde8c97834da729aab215ac3f1f090e73a Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Thu, 9 Apr 2026 04:49:46 +0000 Subject: [PATCH] Add algo->reset after every 1000 calls of alltoallv --- src/core/algorithm.cc | 4 ++-- src/ext/collectives/alltoallv/alltoallv_fullmesh.cu | 9 +++++++++ .../collectives/include/alltoallv/alltoallv_fullmesh.hpp | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/core/algorithm.cc b/src/core/algorithm.cc index 31c98f15..07da9045 100644 --- a/src/core/algorithm.cc +++ b/src/core/algorithm.cc @@ -73,7 +73,7 @@ Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; void NativeAlgorithm::reset() { contexts_.clear(); - initialized_ = false; + initialized_ = true; } void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName, @@ -198,4 +198,4 @@ std::shared_ptr DslAlgorithm::build() { return shared_from_this(); } // TODO: implement this void DslAlgorithm::reset() {} -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu index 4a57d30d..db3cbfd7 100644 --- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu +++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu @@ -81,6 +81,7 @@ std::shared_ptr AlltoallvFullmesh::build() { return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype); }); + self->algo_ = alltoallvAlgo; return alltoallvAlgo; } @@ -240,6 +241,14 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc( outputSize, cudaMemcpyDeviceToDevice, stream)); } + static int cnt; + if (cnt++ % 1000 == 0) { + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); + if (auto algo = algo_.lock()) { + algo->reset(); + } + } + if (cudaGetLastError() == cudaSuccess) { return CommResult::CommSuccess; } diff --git a/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp b/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp index 22c1cf72..5ba9753a 100644 --- a/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp +++ b/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp @@ -56,6 +56,7 @@ class AlltoallvFullmesh : public AlgorithmBuilder { std::vector conns_; int worldSize_; MultiNodeMode multiNodeMode_ = MultiNodeMode::SingleNode; + std::weak_ptr algo_; // back-ref for calling reset() }; } // namespace collective