mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-29 10:57:27 +00:00
fix
This commit is contained in:
@@ -707,9 +707,9 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
|
||||
// weighted gather below.
|
||||
cg::this_grid().sync();
|
||||
|
||||
EP_DEVICE_ASSERT(numTopk <= WARP_SIZE and hiddenBf16Int4 <= numThreads);
|
||||
EP_DEVICE_ASSERT(numTopk <= WARP_SIZE);
|
||||
static_assert(kHidden % (WARP_SIZE * kNumBf16PerInt4) == 0, "Invalid vectorization");
|
||||
if (threadId < hiddenBf16Int4) {
|
||||
for (size_t hiddenIdx = threadId; hiddenIdx < hiddenBf16Int4; hiddenIdx += numThreads) {
|
||||
for (int tokenIdx = smId; tokenIdx < numCombinedTokens; tokenIdx += numSms) {
|
||||
int regTopkIdx[kNumMaxTopk];
|
||||
float regTopkWeights[kNumMaxTopk];
|
||||
@@ -726,7 +726,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
|
||||
(regTopkIdx[i] * numMaxDispatchTokensPerRank + tokenIdx) * numBytesPerSlot);
|
||||
auto stagedRecvRow = reinterpret_cast<const uint8_t*>(stagedRecvType);
|
||||
|
||||
auto stagedVec = ld_nc_global(reinterpret_cast<const int4*>(stagedRecvRow) + threadId);
|
||||
auto stagedVec = ld_nc_global(reinterpret_cast<const int4*>(stagedRecvRow) + hiddenIdx);
|
||||
const auto stagedValues = reinterpret_cast<nv_bfloat16*>(&stagedVec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumBf16PerInt4; ++j)
|
||||
@@ -737,7 +737,7 @@ MSCCLPP_DEVICE_INLINE void combineRecv(void* output, void* stagedRecv, int64_t*
|
||||
auto combinedOutput = reinterpret_cast<OutputType*>(&combinedInt4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumBf16PerInt4; ++j) combinedOutput[j] = static_cast<OutputType>(combinedValues[j]);
|
||||
(reinterpret_cast<int4*>(output) + tokenIdx * hiddenBf16Int4)[threadId] = combinedInt4;
|
||||
(reinterpret_cast<int4*>(output) + tokenIdx * hiddenBf16Int4)[hiddenIdx] = combinedInt4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user