[fix]: fix forward cache (maybe)

This commit is contained in:
mrhaoxx
2026-01-31 21:19:46 +00:00
parent e1e64f7948
commit 9efe1317b1
5 changed files with 21 additions and 14 deletions

View File

@@ -302,7 +302,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
)
# Determine loader and base key format based on method
-if self.method == "AMXBF16_SFT":
+if "BF16" in self.method:
# BF16 mode: Load from HuggingFace model path
loader = BF16SafeTensorLoader(self.weight_path)
base_key = f"model.layers.{self.layer_idx}"
@@ -323,7 +323,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
# Stack expert weights: [num_experts, ...]
# For BF16: weights are already tensors
# For SafeTensorLoader: weights might be numpy arrays in nested lists
-if self.method == "AMXBF16_SFT":
+if "BF16" in self.method:
# BF16SafeTensorLoader returns list of tensors
self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
self.up_proj = torch.stack(up_weights, dim=0).contiguous()