Fix device assert (#522)

* Fixed a bug that external `assert()`s may not be compiled with mscclpp
headers
* Use a macro assert instead of a function
This commit is contained in:
Changho Hwang
2025-05-12 13:38:11 -07:00
committed by GitHub
parent a464b9f21e
commit 5205618c4a
4 changed files with 27 additions and 38 deletions

View File

@@ -12,11 +12,7 @@
#if !defined(DEBUG_BUILD)
#define __assert_fail(__assertion, __file, __line, __function) ;
namespace mscclpp {
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char* msg) {}
} // namespace mscclpp
#define mscclpp_assert_device(__cond, __msg)
#else // defined(DEBUG_BUILD)
@@ -28,13 +24,12 @@ extern "C" __host__ __device__ void __assert_fail(const char *__assertion, const
const char *__function) __THROW;
#endif // !defined(MSCCLPP_DEVICE_HIP)
namespace mscclpp {
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char *msg) {
if (!cond) {
__assert_fail(msg, __FILE__, __LINE__, __PRETTY_FUNCTION__);
}
}
} // namespace mscclpp
#define mscclpp_assert_device(__cond, __msg) \
do { \
if (!(__cond)) { \
__assert_fail(__msg, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} while (0)
#endif // !defined(DEBUG_BUILD)

View File

@@ -183,7 +183,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle {
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE auto unpackPacket(uint64_t index, uint32_t flag, int64_t maxSpinCount = -1) {
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
mscclpp_assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
return reinterpret_cast<PacketType*>(packetBuffer_)[index].read(flag, maxSpinCount);
}
@@ -207,7 +207,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle {
int64_t maxSpinCount = -1) {
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
"Unsupported packet type");
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
mscclpp_assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
copyFromPackets<PacketType>(reinterpret_cast<char*>(src_) + originOffset,
reinterpret_cast<char*>(packetBuffer_) + targetOffset, originBytes, threadId,
numThreads, flag, maxSpinCount);

View File

@@ -9,31 +9,27 @@
#if defined(MSCCLPP_DEVICE_COMPILE)
// If a spin is stuck, print a warning and keep spinning.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (__cond) { \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
[[maybe_unused]] int64_t __spin_cnt = 0; \
while (__cond) { \
mscclpp_assert_device((__max_spin_cnt < 0 || __spin_cnt++ != __max_spin_cnt), #__cond); \
} \
} while (0);
// the as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2
// this is specially useful when __cond1 is faster to check
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
[[maybe_unused]] int64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
mscclpp_assert_device((__max_spin_cnt < 0 || __spin_cnt++ != __max_spin_cnt), #__cond1 #__cond2); \
} \
} while (0);
#endif // defined(MSCCLPP_DEVICE_COMPILE)

View File

@@ -59,8 +59,6 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
deviceSyncer.sync(gridDim.x);
for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
if (dev_ptr[idx] != ((nranks * (nranks - 1)) / 2)) {
__assert_fail("dev_ptr[idx] != nranks", __FILE__, __LINE__, __PRETTY_FUNCTION__);
}
mscclpp_assert_device(dev_ptr[idx] == ((nranks * (nranks - 1)) / 2), "dev_ptr[idx] != nranks");
}
}