Add algo->reset after every 1000 calls of alltoallv

This commit is contained in:
Qinghua Zhou
2026-04-09 04:49:46 +00:00
parent 1d271f4cc7
commit 02fdcffde8
3 changed files with 12 additions and 2 deletions

View File

@@ -73,7 +73,7 @@ Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_;
/// Drops every cached context and leaves the algorithm flagged as initialized.
///
/// The previous body stored `false` into `initialized_` and then immediately
/// overwrote it with `true`; the first store was dead, so only the effective
/// final assignment is kept here.
///
/// NOTE(review): leaving `initialized_` set to true presumably means callers
/// skip any one-time initialization after a reset and only rebuild contexts --
/// confirm this is the intended behavior.
void NativeAlgorithm::reset() {
  contexts_.clear();
  initialized_ = true;
}
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
@@ -198,4 +198,4 @@ std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
// TODO: implement this.
// Currently a no-op: the body is empty, so calling reset() on a DslAlgorithm
// changes no state.
void DslAlgorithm::reset() {}
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -81,6 +81,7 @@ std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {
return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
});
self->algo_ = alltoallvAlgo;
return alltoallvAlgo;
}
@@ -240,6 +241,14 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
outputSize, cudaMemcpyDeviceToDevice, stream));
}
static int cnt;
if (cnt++ % 1000 == 0) {
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
if (auto algo = algo_.lock()) {
algo->reset();
}
}
if (cudaGetLastError() == cudaSuccess) {
return CommResult::CommSuccess;
}

View File

@@ -56,6 +56,7 @@ class AlltoallvFullmesh : public AlgorithmBuilder {
std::vector<Connection> conns_;
int worldSize_;
MultiNodeMode multiNodeMode_ = MultiNodeMode::SingleNode;
// Non-owning back-reference to the Algorithm object that build() returns
// (build() assigns it just before returning). alltoallvKernelFunc locks it to
// call reset(); being a weak_ptr, it does not extend the Algorithm's lifetime.
std::weak_ptr<Algorithm> algo_; // back-ref for calling reset()
};
} // namespace collective