diff --git a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py index 35fea8b..6da0a59 100644 --- a/exllamav3/modules/block_sparse_mlp.py +++ b/exllamav3/modules/block_sparse_mlp.py @@ -113,6 +113,8 @@ def routing_dots(bsz, cfg, y, params): activate_all_experts = params.get("activate_all_experts") if activate_all_experts: routing_weights = router_logits.sigmoid() + if cfg.e_score_correction_bias is not None: + routing_weights += cfg.e_score_correction_bias.unsqueeze(0) factor = cfg.routed_scaling_factor / (routing_weights.sum(dim = -1, keepdim = True) + 1e-20) routing_weights *= factor selected_experts = (