From cb045249eabefe72ee89a617d06214fcba97fa07 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 26 Jun 2026 00:02:58 +0000 Subject: [PATCH] fix --- src/ext/ep/kernels/low_latency.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ext/ep/kernels/low_latency.cu b/src/ext/ep/kernels/low_latency.cu index b7957643..e4306222 100644 --- a/src/ext/ep/kernels/low_latency.cu +++ b/src/ext/ep/kernels/low_latency.cu @@ -707,9 +707,9 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t* // weighted gather below. cg::this_grid().sync(); - EP_DEVICE_ASSERT(numTopk <= WARP_SIZE and hiddenBf16Int4 <= numThreads); + EP_DEVICE_ASSERT(numTopk <= WARP_SIZE); static_assert(kHidden % (WARP_SIZE * kNumBf16PerInt4) == 0, "Invalid vectorization"); - if (threadId < hiddenBf16Int4) { + for (size_t hiddenIdx = threadId; hiddenIdx < hiddenBf16Int4; hiddenIdx += numThreads) { for (int tokenIdx = smId; tokenIdx < numCombinedTokens; tokenIdx += numSms) { int regTopkIdx[kNumMaxTopk]; float regTopkWeights[kNumMaxTopk]; @@ -726,7 +726,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t* (regTopkIdx[i] * numMaxDispatchTokensPerRank + tokenIdx) * numBytesPerSlot); auto stagedRecvRow = reinterpret_cast(stagedRecvType); - auto stagedVec = ld_nc_global(reinterpret_cast(stagedRecvRow) + threadId); + auto stagedVec = ld_nc_global(reinterpret_cast(stagedRecvRow) + hiddenIdx); const auto stagedValues = reinterpret_cast(&stagedVec); #pragma unroll for (int j = 0; j < kNumBf16PerInt4; ++j) @@ -737,7 +737,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t* auto combinedOutput = reinterpret_cast(&combinedInt4); #pragma unroll for (int j = 0; j < kNumBf16PerInt4; ++j) combinedOutput[j] = static_cast(combinedValues[j]); - (reinterpret_cast(output) + tokenIdx * hiddenBf16Int4)[threadId] = combinedInt4; + (reinterpret_cast(output) + tokenIdx * hiddenBf16Int4)[hiddenIdx] = combinedInt4; } } }