mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 15:30:05 +00:00
[fix]: prequant weight load
This commit is contained in:
@@ -219,6 +219,13 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
config.gate_scales = self._gate_scale_ptrs
|
||||
config.up_scales = self._up_scale_ptrs
|
||||
config.down_scales = self._down_scale_ptrs
|
||||
# Also provide BF16 weight pointers for backward gradient computation.
|
||||
# C++ backward needs BF16 base weights to compute gate/up LoRA B gradients
|
||||
# through the gated MLP chain (grad_hidden = down_proj^T @ grad_output).
|
||||
if getattr(self, "_bf16_gate_proj", None) is not None:
|
||||
config.gate_proj = self._bf16_gate_proj.data_ptr()
|
||||
config.up_proj = self._bf16_up_proj.data_ptr()
|
||||
config.down_proj = self._bf16_down_proj.data_ptr()
|
||||
else:
|
||||
# Flat BF16 buffer path
|
||||
config.gate_proj = self.gate_proj.data_ptr()
|
||||
|
||||
Reference in New Issue
Block a user