Add poll() for semaphores (#181)

This commit is contained in:
Changho Hwang
2023-09-15 15:40:44 +08:00
committed by GitHub
parent d2f13f1e54
commit 3aa72098d9
12 changed files with 117 additions and 31 deletions

View File

@@ -21,7 +21,8 @@ struct DeviceSyncer {
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
/// finished.
/// @param blockNum The number of blocks that will synchronize.
__forceinline__ __device__ void sync(int blockNum) {
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
int maxOldCnt = blockNum - 1;
__syncthreads();
if (blockNum == 1) return;
@@ -33,12 +34,12 @@ struct DeviceSyncer {
if (atomicAdd(&count_, 1) == maxOldCnt) {
flag_ = 1;
}
POLL_MAYBE_JAILBREAK(!flag_, 1000000000);
POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
} else {
if (atomicSub(&count_, 1) == 1) {
flag_ = 0;
}
POLL_MAYBE_JAILBREAK(flag_, 1000000000);
POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
}
isAdd_ = tmpIsAdd;
}

View File

@@ -35,8 +35,9 @@ struct FifoDeviceHandle {
/// Push a trigger to the FIFO.
///
/// @param trigger The trigger to push.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
/// @return The new head of the FIFO.
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger, int64_t maxSpinCount = 1000000) {
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
// make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host
// side
@@ -49,7 +50,7 @@ struct FifoDeviceHandle {
// condition is not met.
if (curFifoHead >= size + *(this->tailReplica)) {
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= size + *((volatile uint64_t*)this->tailReplica),
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, 1000000);
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, maxSpinCount);
}
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % size]);
@@ -60,11 +61,12 @@ struct FifoDeviceHandle {
/// Wait until there is a place in the FIFO to push a trigger.
///
/// @param curFifoHead The current head of the FIFO.
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void sync(uint64_t curFifoHead, int64_t maxSpinCount = 1000000) {
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
// to wait for cudaMemcpy to be done.
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % size]) != 0,
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, maxSpinCount);
}
#endif // __CUDACC__

View File

@@ -58,10 +58,11 @@ union LLPacket {
/// Read 8 bytes of data from the packet.
/// @param flag The flag to read.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
/// @return The 8-byte data read.
__forceinline__ __device__ uint2 read(uint32_t flag) {
__forceinline__ __device__ uint2 read(uint32_t flag, int64_t maxSpinCount = 100000000) {
uint2 data;
POLL_MAYBE_JAILBREAK(readOnce(flag, data), 100000000);
POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount);
return data;
}

View File

@@ -12,7 +12,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
// If a spin is stuck, escape from it and set status to 1.
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
do { \
uint64_t __spin_cnt = 0; \
int64_t __spin_cnt = 0; \
__status = 0; \
while (__cond) { \
if (__spin_cnt++ == __max_spin_cnt) { \
@@ -25,7 +25,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
// If a spin is stuck, print a warning and keep spinning.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
int64_t __spin_cnt = 0; \
while (__cond) { \
if (__spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
@@ -37,7 +37,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
// this is specially useful when __cond1 is faster to check
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
int64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \

View File

@@ -158,8 +158,13 @@ struct ProxyChannelDeviceHandle {
fifo_.sync(curFifoHead);
}
/// Check if the proxy channel has been signaled.
/// @return true if the proxy channel has been signaled.
__forceinline__ __device__ bool poll() { return semaphore_.poll(); }
/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { semaphore_.wait(); }
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); }
#endif // __CUDACC__
};
@@ -217,8 +222,13 @@ struct SimpleProxyChannelDeviceHandle {
/// Push a @ref TriggerSync to the FIFO.
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
/// Check if the proxy channel has been signaled.
/// @return true if the proxy channel has been signaled.
__forceinline__ __device__ bool poll() { return proxyChan_.poll(); }
/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { proxyChan_.wait(); }
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { proxyChan_.wait(maxSpinCount); }
#endif // __CUDACC__
};

View File

@@ -104,8 +104,13 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
/// Signal the remote host.
void signal();
/// Check if the remote host has signaled.
/// @return true if the remote host has signaled.
bool poll();
/// Wait for the remote host to signal.
void wait();
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
void wait(int64_t maxSpinCount = 10000000);
private:
std::shared_ptr<Connection> connection_;

View File

@@ -11,10 +11,18 @@ namespace mscclpp {
/// Device-side handle for @ref Host2DeviceSemaphore.
struct Host2DeviceSemaphoreDeviceHandle {
#ifdef __CUDACC__
/// Poll if the host has signaled.
/// @return true if the host has signaled.
__forceinline__ __device__ bool poll() {
bool signaled = (*(volatile uint64_t*)(inboundSemaphoreId) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
return signaled;
}
/// Wait for the host to signal.
__forceinline__ __device__ void wait() {
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) {
(*expectedInboundSemaphoreId) += 1;
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 100000000);
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), maxSpinCount);
}
#endif // __CUDACC__
@@ -25,10 +33,18 @@ struct Host2DeviceSemaphoreDeviceHandle {
/// Device-side handle for @ref SmDevice2DeviceSemaphore.
struct SmDevice2DeviceSemaphoreDeviceHandle {
#ifdef __CUDACC__
/// Poll if the remote device has signaled.
/// @return true if the remote device has signaled.
__forceinline__ __device__ bool poll() {
bool signaled = ((*inboundSemaphoreId) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
return signaled;
}
/// Wait for the remote device to signal.
__forceinline__ __device__ void wait() {
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) {
(*expectedInboundSemaphoreId) += 1;
POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 100000000);
POLL_MAYBE_JAILBREAK((*inboundSemaphoreId) < (*expectedInboundSemaphoreId), maxSpinCount);
}
/// Signal the remote device.

View File

@@ -326,8 +326,13 @@ struct SmChannelDeviceHandle {
/// Read the counter of the local semaphore.
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
/// Check if the remote semaphore has signaled.
/// @return true if the remote semaphore has signaled.
__forceinline__ __device__ bool poll() { return semaphore_.poll(); }
/// Wait for the remote semaphore to send a signal.
__forceinline__ __device__ void wait() { semaphore_.wait(); }
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); }
#endif // __CUDACC__
};

View File

@@ -29,7 +29,8 @@ void register_semaphore(nb::module_& m) {
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2HostSemaphore::connection)
.def("signal", &Host2HostSemaphore::signal)
.def("wait", &Host2HostSemaphore::wait);
.def("poll", &Host2HostSemaphore::poll)
.def("wait", &Host2HostSemaphore::wait, nb::arg("max_spin_count") = 10000000);
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
smDevice2DeviceSemaphore

View File

@@ -58,9 +58,19 @@ MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
*outboundSemaphore_ + 1);
}
MSCCLPP_API_CPP void Host2HostSemaphore::wait() {
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
bool signaled = (*(volatile uint64_t*)localInboundSemaphore_.get() > (*expectedInboundSemaphore_));
if (signaled) (*expectedInboundSemaphore_) += 1;
return signaled;
}
MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) {
(*expectedInboundSemaphore_) += 1;
int64_t spinCount = 0;
while (*(volatile uint64_t*)localInboundSemaphore_.get() < (*expectedInboundSemaphore_)) {
if (spinCount++ == maxSpinCount) {
throw Error("Host2HostSemaphore::wait timed out", ErrorCode::Timeout);
}
}
}

View File

@@ -134,6 +134,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase {
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIbOnly, void* sendBuff,
size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0);
void testPingPong(bool useIbOnly, bool waitWithPoll);
void testPacketPingPong(bool useIbOnly);
void testPacketPingPongPerf(bool useIbOnly);

View File

@@ -67,7 +67,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> gChannelOneToOneTestConstProxyChans;
__global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
__global__ void kernelProxyPingPong(int* buff, int rank, int nElem, bool waitWithPoll, int* ret) {
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = gChannelOneToOneTestConstProxyChans;
volatile int* sendBuff = (volatile int*)buff;
int nTries = 1000;
@@ -76,7 +76,20 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
for (int i = 0; i < nTries; i++) {
if (rank == 0) {
if (i > 0) {
if (threadIdx.x == 0) proxyChan.wait();
if (threadIdx.x == 0) {
if (waitWithPoll) {
int spin = 1000000;
while (!proxyChan.poll() && spin > 0) {
spin--;
}
if (spin == 0) {
// printf("rank 0 ERROR: poll timeout\n");
*ret = 1;
}
} else {
proxyChan.wait();
}
}
__syncthreads();
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
if (sendBuff[j] != rank1Offset + i - 1 + j) {
@@ -94,7 +107,20 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
if (threadIdx.x == 0) proxyChan.putWithSignal(0, nElem * sizeof(int));
}
if (rank == 1) {
if (threadIdx.x == 0) proxyChan.wait();
if (threadIdx.x == 0) {
if (waitWithPoll) {
int spin = 1000000;
while (!proxyChan.poll() && spin > 0) {
spin--;
}
if (spin == 0) {
// printf("rank 0 ERROR: poll timeout\n");
*ret = 1;
}
} else {
proxyChan.wait();
}
}
__syncthreads();
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
if (sendBuff[j] != i + j) {
@@ -120,14 +146,14 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) {
}
}
TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
if (gEnv->rank >= numRanksToUse) return;
const int nElem = 4 * 1024 * 1024;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
setupMeshConnections(proxyChannels, true, buff.get(), nElem * sizeof(int));
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
@@ -140,22 +166,22 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, waitWithPoll, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, waitWithPoll, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, waitWithPoll, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
@@ -163,6 +189,14 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
proxyService->stopProxy();
}
TEST_F(ProxyChannelOneToOneTest, PingPong) { testPingPong(false, false); }
TEST_F(ProxyChannelOneToOneTest, PingPongIb) { testPingPong(true, false); }
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { testPingPong(false, true); }
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { testPingPong(true, true); }
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
template <bool CheckCorrectness>