Change the d2d (device-to-device) signal to use atomicFetchAdd instead of add + atomicStore

This commit is contained in:
Binyang Li
2026-02-19 20:09:27 +00:00
parent 4701ae3a95
commit 6db7785623
4 changed files with 6 additions and 26 deletions

View File

@@ -82,7 +82,6 @@ class MemoryDevice2DeviceSemaphore {
private:
Semaphore semaphore_;
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
detail::UniqueGpuPtr<uint64_t> outboundToken_;
public:
/// Constructor.

View File

@@ -82,19 +82,17 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
/// Signal remote device, ensures prior memory ops complete.
///
/// NOTE(review): this listing looks like a commit diff rendered without +/- markers, so the
/// pre-change lines (`incOutbound()` followed by `atomicStore`) and the post-change lines
/// (`atomicFetchAdd`) appear side by side. Per the commit title, presumably only the
/// `atomicFetchAdd` calls remain in the resulting file -- confirm against the real source.
MSCCLPP_DEVICE_INLINE void signal() {
// Pre-change form: bump the local outbound counter and keep the value to publish.
auto outbound = incOutbound();
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
// Using memoryOrderSeqCst is faster for A100.
// Pre-change form: overwrite the remote token with the local counter value.
atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst);
// Post-change form: add 1 to the remote token directly; no local counter is involved.
atomicFetchAdd(remoteInboundToken, 1UL, memoryOrderSeqCst);
#else
// Pre-change form: same store as above, but with release ordering.
atomicStore(remoteInboundToken, outbound, memoryOrderRelease);
// Post-change form: same add as above, but with release ordering.
atomicFetchAdd(remoteInboundToken, 1UL, memoryOrderRelease);
#endif
}
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
///
/// NOTE(review): as rendered here, the pre-change lines (`incOutbound()` + `atomicStore`) and the
/// post-change line (`atomicFetchAdd`) are interleaved. Per the commit title, presumably only the
/// `atomicFetchAdd` call remains in the resulting file -- confirm against the real source.
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
// Pre-change form: bump the local outbound counter and keep the value to publish.
auto outbound = incOutbound();
// Pre-change form: overwrite the remote token with that value using relaxed ordering.
atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed);
// Post-change form: add 1 to the remote token directly using relaxed ordering.
atomicFetchAdd(remoteInboundToken, 1UL, memoryOrderRelaxed);
}
/// Thread-safe read of expected inbound value.
@@ -121,27 +119,13 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderRelaxed);
}
/// Thread-safe read of outbound value.
/// Performs a relaxed atomic load of `outboundToken` with `scopeDevice`.
/// NOTE(review): the hunk header (27 lines before, 13 after) and the commit title suggest this
/// accessor is deleted by the commit together with `outboundToken` -- confirm against the real source.
/// @return The outbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
return atomicLoad<uint64_t, scopeDevice>(outboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of outbound value.
/// Adds 1 to `outboundToken` with a relaxed atomic fetch-add at `scopeDevice`.
/// NOTE(review): the hunk header (27 lines before, 13 after) and the commit title suggest this
/// helper is deleted by the commit together with `outboundToken` -- confirm against the real source.
/// @return The incremented outbound value.
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
// The trailing `+ 1` presumably converts the pre-increment value returned by atomicFetchAdd
// into the post-increment value -- verify against atomicFetchAdd's definition.
return atomicFetchAdd<uint64_t, scopeDevice>(outboundToken, 1, memoryOrderRelaxed) + 1;
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where the remote device will write its semaphore value and the local device will read it.
uint64_t* inboundToken;
/// A local memory space where the local device stores the semaphore value to be written to the remote device.
uint64_t* outboundToken;
/// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the
/// remote device.
/// A remote memory space where the local device writes to signal the remote device. This points to the
/// inboundToken of the remote device.
uint64_t* remoteInboundToken;
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.

View File

@@ -43,7 +43,6 @@ void register_semaphore(nb::module_& m) {
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {

View File

@@ -184,8 +184,7 @@ MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) {
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const Semaphore& semaphore)
: semaphore_(semaphore),
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
outboundToken_(detail::gpuCallocUnique<uint64_t>()) {
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()) {
if (connection().localDevice().type != DeviceType::GPU) {
throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
}
@@ -202,7 +201,6 @@ MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSe
device.remoteInboundToken = reinterpret_cast<uint64_t*>(semaphore_.remoteMemory().data());
device.inboundToken = reinterpret_cast<uint64_t*>(semaphore_.localMemory().data());
device.expectedInboundToken = expectedInboundToken_.get();
device.outboundToken = outboundToken_.get();
return device;
};