mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 23:37:58 +00:00
[fix]: fix forward cache (maybe)
This commit is contained in:
@@ -302,7 +302,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
)
|
||||
|
||||
# Determine loader and base key format based on method
|
||||
- if self.method == "AMXBF16_SFT":
|
||||
+ if "BF16" in self.method:
|
||||
# BF16 mode: Load from HuggingFace model path
|
||||
loader = BF16SafeTensorLoader(self.weight_path)
|
||||
base_key = f"model.layers.{self.layer_idx}"
|
||||
@@ -323,7 +323,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
# Stack expert weights: [num_experts, ...]
|
||||
# For BF16: weights are already tensors
|
||||
# For SafeTensorLoader: weights might be numpy arrays in nested lists
|
||||
- if self.method == "AMXBF16_SFT":
|
||||
+ if "BF16" in self.method:
|
||||
# BF16SafeTensorLoader returns list of tensors
|
||||
self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
|
||||
self.up_proj = torch.stack(up_weights, dim=0).contiguous()
|
||||
|
||||
Reference in New Issue
Block a user