mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Fix device assert (#522)
* Fixed a bug that external `assert()`s may not be compiled with mscclpp headers * Use a macro assert instead of a function
This commit is contained in:
@@ -12,11 +12,7 @@
|
||||
|
||||
#if !defined(DEBUG_BUILD)
|
||||
|
||||
#define __assert_fail(__assertion, __file, __line, __function) ;
|
||||
|
||||
namespace mscclpp {
|
||||
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char* msg) {}
|
||||
} // namespace mscclpp
|
||||
#define mscclpp_assert_device(__cond, __msg)
|
||||
|
||||
#else // defined(DEBUG_BUILD)
|
||||
|
||||
@@ -28,13 +24,12 @@ extern "C" __host__ __device__ void __assert_fail(const char *__assertion, const
|
||||
const char *__function) __THROW;
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
|
||||
namespace mscclpp {
|
||||
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char *msg) {
|
||||
if (!cond) {
|
||||
__assert_fail(msg, __FILE__, __LINE__, __PRETTY_FUNCTION__);
|
||||
}
|
||||
}
|
||||
} // namespace mscclpp
|
||||
#define mscclpp_assert_device(__cond, __msg) \
|
||||
do { \
|
||||
if (!(__cond)) { \
|
||||
__assert_fail(__msg, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif // !defined(DEBUG_BUILD)
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle {
|
||||
///
|
||||
template <typename PacketType = LL16Packet>
|
||||
MSCCLPP_DEVICE_INLINE auto unpackPacket(uint64_t index, uint32_t flag, int64_t maxSpinCount = -1) {
|
||||
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
|
||||
mscclpp_assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
|
||||
return reinterpret_cast<PacketType*>(packetBuffer_)[index].read(flag, maxSpinCount);
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle {
|
||||
int64_t maxSpinCount = -1) {
|
||||
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
|
||||
"Unsupported packet type");
|
||||
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
|
||||
mscclpp_assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
|
||||
copyFromPackets<PacketType>(reinterpret_cast<char*>(src_) + originOffset,
|
||||
reinterpret_cast<char*>(packetBuffer_) + targetOffset, originBytes, threadId,
|
||||
numThreads, flag, maxSpinCount);
|
||||
|
||||
@@ -9,31 +9,27 @@
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
// If a spin is stuck, print a warning and keep spinning.
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
do { \
|
||||
int64_t __spin_cnt = 0; \
|
||||
while (__cond) { \
|
||||
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
|
||||
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
} \
|
||||
} \
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
do { \
|
||||
[[maybe_unused]] int64_t __spin_cnt = 0; \
|
||||
while (__cond) { \
|
||||
mscclpp_assert_device((__max_spin_cnt < 0 || __spin_cnt++ != __max_spin_cnt), #__cond); \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
// the as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2
|
||||
// this is specially useful when __cond1 is faster to check
|
||||
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
|
||||
do { \
|
||||
int64_t __spin_cnt = 0; \
|
||||
while (true) { \
|
||||
if (!(__cond1)) { \
|
||||
break; \
|
||||
} else if (!(__cond2)) { \
|
||||
break; \
|
||||
} \
|
||||
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
|
||||
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
} \
|
||||
} \
|
||||
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
|
||||
do { \
|
||||
[[maybe_unused]] int64_t __spin_cnt = 0; \
|
||||
while (true) { \
|
||||
if (!(__cond1)) { \
|
||||
break; \
|
||||
} else if (!(__cond2)) { \
|
||||
break; \
|
||||
} \
|
||||
mscclpp_assert_device((__max_spin_cnt < 0 || __spin_cnt++ != __max_spin_cnt), #__cond1 #__cond2); \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
@@ -59,8 +59,6 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
|
||||
if (dev_ptr[idx] != ((nranks * (nranks - 1)) / 2)) {
|
||||
__assert_fail("dev_ptr[idx] != nranks", __FILE__, __LINE__, __PRETTY_FUNCTION__);
|
||||
}
|
||||
mscclpp_assert_device(dev_ptr[idx] == ((nranks * (nranks - 1)) / 2), "dev_ptr[idx] != nranks");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user