Use int64_t for the expert stride to avoid overflow

This commit is contained in:
Feng Shijie
2025-08-21 06:58:55 +00:00
parent 9fbcc8f8a4
commit 85976b0b87
3 changed files with 19 additions and 18 deletions

View File

@@ -301,7 +301,7 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
{
int up_stride = N / 2 / NLane;
for(int eid = 0; eid < experts_cnt; ++eid)
for(long eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
@@ -319,9 +319,9 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
k2;
long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
n1 * KPack + k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}
@@ -330,7 +330,7 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
}
else
{
for(int eid = 0; eid < experts_cnt; ++eid)
for(long eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
@@ -344,9 +344,9 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
k2;
long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
n1 * KPack + k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}