mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Add poll() for semaphores (#181)
This commit is contained in:
@@ -21,7 +21,8 @@ struct DeviceSyncer {
|
||||
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
|
||||
/// finished.
|
||||
/// @param blockNum The number of blocks that will synchronize.
|
||||
__forceinline__ __device__ void sync(int blockNum) {
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
__forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
|
||||
int maxOldCnt = blockNum - 1;
|
||||
__syncthreads();
|
||||
if (blockNum == 1) return;
|
||||
@@ -33,12 +34,12 @@ struct DeviceSyncer {
|
||||
if (atomicAdd(&count_, 1) == maxOldCnt) {
|
||||
flag_ = 1;
|
||||
}
|
||||
POLL_MAYBE_JAILBREAK(!flag_, 1000000000);
|
||||
POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
|
||||
} else {
|
||||
if (atomicSub(&count_, 1) == 1) {
|
||||
flag_ = 0;
|
||||
}
|
||||
POLL_MAYBE_JAILBREAK(flag_, 1000000000);
|
||||
POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
|
||||
}
|
||||
isAdd_ = tmpIsAdd;
|
||||
}
|
||||
|
||||
@@ -35,8 +35,9 @@ struct FifoDeviceHandle {
|
||||
/// Push a trigger to the FIFO.
|
||||
///
|
||||
/// @param trigger The trigger to push.
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
/// @return The new head of the FIFO.
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger, int64_t maxSpinCount = 1000000) {
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
// make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host
|
||||
// side
|
||||
@@ -49,7 +50,7 @@ struct FifoDeviceHandle {
|
||||
// condition is not met.
|
||||
if (curFifoHead >= size + *(this->tailReplica)) {
|
||||
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= size + *((volatile uint64_t*)this->tailReplica),
|
||||
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, 1000000);
|
||||
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, maxSpinCount);
|
||||
}
|
||||
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % size]);
|
||||
@@ -60,11 +61,12 @@ struct FifoDeviceHandle {
|
||||
/// Wait until there is a place in the FIFO to push a trigger.
|
||||
///
|
||||
/// @param curFifoHead The current head of the FIFO.
|
||||
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
__forceinline__ __device__ void sync(uint64_t curFifoHead, int64_t maxSpinCount = 1000000) {
|
||||
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
|
||||
// to wait for cudaMemcpy to be done.
|
||||
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % size]) != 0,
|
||||
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
|
||||
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, maxSpinCount);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
|
||||
@@ -58,10 +58,11 @@ union LLPacket {
|
||||
|
||||
/// Read 8 bytes of data from the packet.
|
||||
/// @param flag The flag to read.
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
/// @return The 8-byte data read.
|
||||
__forceinline__ __device__ uint2 read(uint32_t flag) {
|
||||
__forceinline__ __device__ uint2 read(uint32_t flag, int64_t maxSpinCount = 100000000) {
|
||||
uint2 data;
|
||||
POLL_MAYBE_JAILBREAK(readOnce(flag, data), 100000000);
|
||||
POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount);
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
|
||||
// If a spin is stuck, escape from it and set status to 1.
|
||||
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
int64_t __spin_cnt = 0; \
|
||||
__status = 0; \
|
||||
while (__cond) { \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
@@ -25,7 +25,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
|
||||
// If a spin is stuck, print a warning and keep spinning.
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
int64_t __spin_cnt = 0; \
|
||||
while (__cond) { \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
@@ -37,7 +37,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
|
||||
// this is specially useful when __cond1 is faster to check
|
||||
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
int64_t __spin_cnt = 0; \
|
||||
while (true) { \
|
||||
if (!(__cond1)) { \
|
||||
break; \
|
||||
|
||||
@@ -158,8 +158,13 @@ struct ProxyChannelDeviceHandle {
|
||||
fifo_.sync(curFifoHead);
|
||||
}
|
||||
|
||||
/// Check if the proxy channel has been signaled.
|
||||
/// @return true if the proxy channel has been signaled.
|
||||
__forceinline__ __device__ bool poll() { return semaphore_.poll(); }
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
@@ -217,8 +222,13 @@ struct SimpleProxyChannelDeviceHandle {
|
||||
/// Push a @ref TriggerSync to the FIFO.
|
||||
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
|
||||
|
||||
/// Check if the proxy channel has been signaled.
|
||||
/// @return true if the proxy channel has been signaled.
|
||||
__forceinline__ __device__ bool poll() { return proxyChan_.poll(); }
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { proxyChan_.wait(); }
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { proxyChan_.wait(maxSpinCount); }
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
|
||||
@@ -104,8 +104,13 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
|
||||
/// Signal the remote host.
|
||||
void signal();
|
||||
|
||||
/// Check if the remote host has signaled.
|
||||
/// @return true if the remote host has signaled.
|
||||
bool poll();
|
||||
|
||||
/// Wait for the remote host to signal.
|
||||
void wait();
|
||||
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
|
||||
void wait(int64_t maxSpinCount = 10000000);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
|
||||
@@ -11,10 +11,18 @@ namespace mscclpp {
|
||||
/// Device-side handle for @ref Host2DeviceSemaphore.
|
||||
struct Host2DeviceSemaphoreDeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Poll if the host has signaled.
|
||||
/// @return true if the host has signaled.
|
||||
__forceinline__ __device__ bool poll() {
|
||||
bool signaled = (*(volatile uint64_t*)(inboundSemaphoreId) > (*expectedInboundSemaphoreId));
|
||||
if (signaled) (*expectedInboundSemaphoreId) += 1;
|
||||
return signaled;
|
||||
}
|
||||
|
||||
/// Wait for the host to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 100000000);
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), maxSpinCount);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
@@ -25,10 +33,18 @@ struct Host2DeviceSemaphoreDeviceHandle {
|
||||
/// Device-side handle for @ref SmDevice2DeviceSemaphore.
|
||||
struct SmDevice2DeviceSemaphoreDeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Poll if the remote device has signaled.
|
||||
/// @return true if the remote device has signaled.
|
||||
__forceinline__ __device__ bool poll() {
|
||||
bool signaled = ((*inboundSemaphoreId) > (*expectedInboundSemaphoreId));
|
||||
if (signaled) (*expectedInboundSemaphoreId) += 1;
|
||||
return signaled;
|
||||
}
|
||||
|
||||
/// Wait for the remote device to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 100000000);
|
||||
POLL_MAYBE_JAILBREAK((*inboundSemaphoreId) < (*expectedInboundSemaphoreId), maxSpinCount);
|
||||
}
|
||||
|
||||
/// Signal the remote device.
|
||||
|
||||
@@ -326,8 +326,13 @@ struct SmChannelDeviceHandle {
|
||||
/// Read the counter of the local semaphore.
|
||||
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
|
||||
|
||||
/// Check if the remote semaphore has signaled.
|
||||
/// @return true if the remote semaphore has signaled.
|
||||
__forceinline__ __device__ bool poll() { return semaphore_.poll(); }
|
||||
|
||||
/// Wait for the remote semaphore to send a signal.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.
|
||||
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); }
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
|
||||
@@ -29,7 +29,8 @@ void register_semaphore(nb::module_& m) {
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("wait", &Host2HostSemaphore::wait);
|
||||
.def("poll", &Host2HostSemaphore::poll)
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
|
||||
smDevice2DeviceSemaphore
|
||||
|
||||
@@ -58,9 +58,19 @@ MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
|
||||
*outboundSemaphore_ + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Host2HostSemaphore::wait() {
|
||||
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
|
||||
bool signaled = (*(volatile uint64_t*)localInboundSemaphore_.get() > (*expectedInboundSemaphore_));
|
||||
if (signaled) (*expectedInboundSemaphore_) += 1;
|
||||
return signaled;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) {
|
||||
(*expectedInboundSemaphore_) += 1;
|
||||
int64_t spinCount = 0;
|
||||
while (*(volatile uint64_t*)localInboundSemaphore_.get() < (*expectedInboundSemaphore_)) {
|
||||
if (spinCount++ == maxSpinCount) {
|
||||
throw Error("Host2HostSemaphore::wait timed out", ErrorCode::Timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase {
|
||||
|
||||
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIbOnly, void* sendBuff,
|
||||
size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0);
|
||||
void testPingPong(bool useIbOnly, bool waitWithPoll);
|
||||
void testPacketPingPong(bool useIbOnly);
|
||||
void testPacketPingPongPerf(bool useIbOnly);
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
|
||||
|
||||
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> gChannelOneToOneTestConstProxyChans;
|
||||
|
||||
__global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
__global__ void kernelProxyPingPong(int* buff, int rank, int nElem, bool waitWithPoll, int* ret) {
|
||||
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = gChannelOneToOneTestConstProxyChans;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
int nTries = 1000;
|
||||
@@ -76,7 +76,20 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
for (int i = 0; i < nTries; i++) {
|
||||
if (rank == 0) {
|
||||
if (i > 0) {
|
||||
if (threadIdx.x == 0) proxyChan.wait();
|
||||
if (threadIdx.x == 0) {
|
||||
if (waitWithPoll) {
|
||||
int spin = 1000000;
|
||||
while (!proxyChan.poll() && spin > 0) {
|
||||
spin--;
|
||||
}
|
||||
if (spin == 0) {
|
||||
// printf("rank 0 ERROR: poll timeout\n");
|
||||
*ret = 1;
|
||||
}
|
||||
} else {
|
||||
proxyChan.wait();
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
|
||||
if (sendBuff[j] != rank1Offset + i - 1 + j) {
|
||||
@@ -94,7 +107,20 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
if (threadIdx.x == 0) proxyChan.putWithSignal(0, nElem * sizeof(int));
|
||||
}
|
||||
if (rank == 1) {
|
||||
if (threadIdx.x == 0) proxyChan.wait();
|
||||
if (threadIdx.x == 0) {
|
||||
if (waitWithPoll) {
|
||||
int spin = 1000000;
|
||||
while (!proxyChan.poll() && spin > 0) {
|
||||
spin--;
|
||||
}
|
||||
if (spin == 0) {
|
||||
// printf("rank 0 ERROR: poll timeout\n");
|
||||
*ret = 1;
|
||||
}
|
||||
} else {
|
||||
proxyChan.wait();
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
|
||||
if (sendBuff[j] != i + j) {
|
||||
@@ -120,14 +146,14 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
|
||||
void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
|
||||
setupMeshConnections(proxyChannels, true, buff.get(), nElem * sizeof(int));
|
||||
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
|
||||
@@ -140,22 +166,22 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, waitWithPoll, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, waitWithPoll, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, waitWithPoll, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
@@ -163,6 +189,14 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPong) { testPingPong(false, false); }
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongIb) { testPingPong(true, false); }
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { testPingPong(false, true); }
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { testPingPong(true, true); }
|
||||
|
||||
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
|
||||
|
||||
template <bool CheckCorrectness>
|
||||
|
||||
Reference in New Issue
Block a user