diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index ecd13c47..f7ec67d0 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -165,6 +165,7 @@ void gpuFreePhysical(void* ptr); void gpuMemcpyAsync(void* dst, const void* src, size_t bytes, cudaStream_t stream, cudaMemcpyKind kind = cudaMemcpyDefault); void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind = cudaMemcpyDefault); +void gpuMemset(void* ptr, int value, size_t bytes); /// A template function that allocates memory while ensuring that the memory will be freed when the returned object is /// destroyed. @@ -300,6 +301,8 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind); } +inline void memset(void* ptr, int value, size_t bytes) { detail::gpuMemset(ptr, value, bytes); } + /// Check if NVLink SHARP (NVLS) is supported. /// /// @return True if NVLink SHARP (NVLS) is supported, false otherwise. diff --git a/src/core/gpu_utils.cc b/src/core/gpu_utils.cc index 09d5025d..1ce61322 100644 --- a/src/core/gpu_utils.cc +++ b/src/core/gpu_utils.cc @@ -267,6 +267,13 @@ void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind) { MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); } +void gpuMemset(void* ptr, int value, size_t bytes) { + AvoidCudaGraphCaptureGuard cgcGuard; + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, value, bytes, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); +} + } // namespace detail bool isNvlsSupported() { diff --git a/src/core/utils_internal.cc b/src/core/utils_internal.cc index 9504a52c..ea867fff 100644 --- a/src/core/utils_internal.cc +++ b/src/core/utils_internal.cc @@ -263,8 +263,10 @@ std::shared_ptr TokenPool::getToken() { for (int bit = 0; bit < UINT64_WIDTH; bit++) { if (holes & (1UL << bit)) { allocationMap_[i].set(bit); - INFO(MSCCLPP_ALLOC, "TokenPool allocated token at addr %p", baseAddr_ + i * UINT64_WIDTH + bit); - return std::shared_ptr(baseAddr_ + i * UINT64_WIDTH + bit, deleter); + uint64_t* token = baseAddr_ + i * UINT64_WIDTH + bit; + mscclpp::memset(token, 0, sizeof(uint64_t)); + INFO(MSCCLPP_ALLOC, "TokenPool allocated token at addr %p", token); + return std::shared_ptr(token, deleter); } } }