mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Add algo->reset after every 1000 calls of alltoallv
This commit is contained in:
@@ -73,7 +73,7 @@ Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_;
|
||||
|
||||
void NativeAlgorithm::reset() {
|
||||
contexts_.clear();
|
||||
initialized_ = false;
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
|
||||
@@ -198,4 +198,4 @@ std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
|
||||
// TODO: implement this
|
||||
void DslAlgorithm::reset() {}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -81,6 +81,7 @@ std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {
|
||||
return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
|
||||
});
|
||||
|
||||
self->algo_ = alltoallvAlgo;
|
||||
return alltoallvAlgo;
|
||||
}
|
||||
|
||||
@@ -240,6 +241,14 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
outputSize, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
static int cnt;
|
||||
if (cnt++ % 1000 == 0) {
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
|
||||
if (auto algo = algo_.lock()) {
|
||||
algo->reset();
|
||||
}
|
||||
}
|
||||
|
||||
if (cudaGetLastError() == cudaSuccess) {
|
||||
return CommResult::CommSuccess;
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ class AlltoallvFullmesh : public AlgorithmBuilder {
|
||||
std::vector<Connection> conns_;
|
||||
int worldSize_;
|
||||
MultiNodeMode multiNodeMode_ = MultiNodeMode::SingleNode;
|
||||
std::weak_ptr<Algorithm> algo_; // back-ref for calling reset()
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
|
||||
Reference in New Issue
Block a user