[fix]: prequant weight load

This commit is contained in:
mrhaoxx
2026-02-01 15:17:47 +00:00
parent 9efe1317b1
commit 06fb3b5dbf
2 changed files with 64 additions and 1 deletions

View File

@@ -219,6 +219,13 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
config.gate_scales = self._gate_scale_ptrs
config.up_scales = self._up_scale_ptrs
config.down_scales = self._down_scale_ptrs
# Also provide BF16 weight pointers for backward gradient computation,
# but only when the `_bf16_*_proj` attributes are set (see the guard below).
# The C++ backward pass needs the BF16 base weights to compute the gate/up
# LoRA B gradients through the gated MLP chain
# (grad_hidden = down_proj^T @ grad_output).
if getattr(self, "_bf16_gate_proj", None) is not None:
config.gate_proj = self._bf16_gate_proj.data_ptr()
config.up_proj = self._bf16_up_proj.data_ptr()
config.down_proj = self._bf16_down_proj.data_ptr()
else:
# Flat BF16 buffer path
config.gate_proj = self.gate_proj.data_ptr()