Huaiguxu/moe fp8 pertoken scale fix (#2391)

* fix pertoken_scale a_scale dimension

* clang-format

* Fix moe_gemm2_fp8 perTokenScale reference and example.
This commit is contained in:
huaiguxu
2025-06-27 10:24:34 +08:00
committed by GitHub
parent 1749c0409e
commit e1c5172fdb
3 changed files with 26 additions and 9 deletions

View File

@@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
? p_sorted_weights_0[IsInputGemm
? token_offset
: token_offset *
problem.TopK +
(fused_token >>
24)]
: 0.0;
}
else
@@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
? p_sorted_weights_0[IsInputGemm
? token_offset
: token_offset *
problem.TopK +
(fused_token >>
24)]
: 0.0;
}
else