This commit is contained in:
Binyang Li
2026-06-26 00:02:58 +00:00
parent 7b25bd32be
commit cb045249ea

View File

@@ -707,9 +707,9 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
// weighted gather below.
cg::this_grid().sync();
EP_DEVICE_ASSERT(numTopk <= WARP_SIZE and hiddenBf16Int4 <= numThreads);
EP_DEVICE_ASSERT(numTopk <= WARP_SIZE);
static_assert(kHidden % (WARP_SIZE * kNumBf16PerInt4) == 0, "Invalid vectorization");
if (threadId < hiddenBf16Int4) {
for (size_t hiddenIdx = threadId; hiddenIdx < hiddenBf16Int4; hiddenIdx += numThreads) {
for (int tokenIdx = smId; tokenIdx < numCombinedTokens; tokenIdx += numSms) {
int regTopkIdx[kNumMaxTopk];
float regTopkWeights[kNumMaxTopk];
@@ -726,7 +726,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
(regTopkIdx[i] * numMaxDispatchTokensPerRank + tokenIdx) * numBytesPerSlot);
auto stagedRecvRow = reinterpret_cast<const uint8_t*>(stagedRecvType);
auto stagedVec = ld_nc_global(reinterpret_cast<const int4*>(stagedRecvRow) + threadId);
auto stagedVec = ld_nc_global(reinterpret_cast<const int4*>(stagedRecvRow) + hiddenIdx);
const auto stagedValues = reinterpret_cast<nv_bfloat16*>(&stagedVec);
#pragma unroll
for (int j = 0; j < kNumBf16PerInt4; ++j)
@@ -737,7 +737,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
auto combinedOutput = reinterpret_cast<OutputType*>(&combinedInt4);
#pragma unroll
for (int j = 0; j < kNumBf16PerInt4; ++j) combinedOutput[j] = static_cast<OutputType>(combinedValues[j]);
(reinterpret_cast<int4*>(output) + tokenIdx * hiddenBf16Int4)[threadId] = combinedInt4;
(reinterpret_cast<int4*>(output) + tokenIdx * hiddenBf16Int4)[hiddenIdx] = combinedInt4;
}
}
}